"""Train a GRU seq2seq spelling corrector on the full Russian dictionary.

Defaults target an RTX 4080 (50 epochs, batch 1024); the script falls back to
CPU when CUDA is unavailable. Training pairs are generated on the fly: a
dictionary word is the target and, with probability 0.7, a randomly corrupted
copy of it (adjacent swap / substitution / deletion / insertion) is the source;
otherwise the source is the word itself.

The trained weights are exported to ``Data/seq2spell_gpu.bin`` as a flat,
native-endian binary (see ``export`` for the layout).
"""
import os
import random
import struct
import sys
import time

import numpy as np
import torch
import torch.nn as nn

# Vocabulary: indices 0..31 are 'а'..'я', 32 is 'ё', 33 is space, and index
# ALPHA-1 (34) is a single overloaded special token used as decoder start
# input, end-of-sequence marker, padding value and "unknown character" code.
ALPHA = 35
HIDDEN = 256
EMBED_DIM = 64
MAX_LEN = 25
NUM_LAYERS = 2          # depth of both the encoder and the decoder GRU
EPOCHS = 50
BATCH_SIZE = 1024
BATCHES_PER_EPOCH = 300
LR = 0.001

DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'Data')
SAVE_PATH = os.path.join(DATA_DIR, 'seq2spell_gpu.bin')

# Characters that own a real embedding index. Any other character would be
# encoded as ALPHA-1 and collide with the EOS/PAD token, so words containing
# one are excluded from training.
VALID_CHARS = frozenset([chr(ord('а') + i) for i in range(32)] + ['ё', ' '])

TEST_WORDS = ["привет", "прввет", "компуктер", "здраствуте", "пшоел", "кароче",
              "тихналогия", "мсыли", "бсытро", "написатьб", "пчеатаю",
              "клвиатуре", "совршенно"]


def p(*a, **k):
    """Print and flush stdout immediately so progress is visible in piped logs."""
    print(*a, **k)
    sys.stdout.flush()


def ci(c):
    """Map a character to its vocabulary index (unknown characters -> ALPHA-1)."""
    if 'а' <= c <= 'я':
        return ord(c) - ord('а')
    if c == 'ё':
        return 32
    if c == ' ':
        return 33
    return ALPHA - 1


def ic(i):
    """Map a vocabulary index back to its character ('' for the special token)."""
    if 0 <= i <= 31:
        return chr(ord('а') + i)
    if i == 32:
        return 'ё'
    if i == 33:
        return ' '
    return ''


def corrupt(w, rng):
    """Return *w* with 1-3 random edits drawn from *rng*.

    Each edit is one of: swap two adjacent characters, substitute a character,
    delete a character (only while more than 2 remain), or insert a character.
    Substituted/inserted characters are drawn from 'а'..'я' only. Words shorter
    than 3 characters are returned unchanged.
    """
    w = list(w)
    if len(w) < 3:
        return ''.join(w)
    for _ in range(rng.randint(1, 3)):
        op = rng.randint(0, 3)
        if op == 0 and len(w) > 1:
            i = rng.randint(0, len(w) - 2)
            w[i], w[i + 1] = w[i + 1], w[i]
        elif op == 1:
            w[rng.randint(0, len(w) - 1)] = chr(ord('а') + rng.randint(0, 31))
        elif op == 2 and len(w) > 2:
            del w[rng.randint(0, len(w) - 1)]
        elif op == 3:
            w.insert(rng.randint(0, len(w)), chr(ord('а') + rng.randint(0, 31)))
    return ''.join(w)


def enc(word):
    """Encode the first MAX_LEN characters of *word* plus a trailing EOS token."""
    ids = [ci(c) for c in word[:MAX_LEN]]
    ids.append(ALPHA - 1)
    return ids


class M(nn.Module):
    """Encoder-decoder GRU sharing one character embedding table.

    The encoder's final hidden state (all NUM_LAYERS layers) initialises the
    decoder, which is trained with teacher forcing.
    """

    def __init__(self):
        super().__init__()
        # PyTorch only applies inter-layer dropout when there are >= 2 layers
        # (and warns otherwise).
        dropout = 0.15 if NUM_LAYERS > 1 else 0.0
        self.emb = nn.Embedding(ALPHA, EMBED_DIM)
        self.enc = nn.GRU(EMBED_DIM, HIDDEN, num_layers=NUM_LAYERS,
                          batch_first=True, dropout=dropout)
        self.dec = nn.GRU(EMBED_DIM, HIDDEN, num_layers=NUM_LAYERS,
                          batch_first=True, dropout=dropout)
        self.out = nn.Linear(HIDDEN, ALPHA)

    def forward(self, s, t):
        """Return logits of shape (batch, len(t), ALPHA) for source *s*, decoder input *t*."""
        _, h = self.enc(self.emb(s))
        return self.out(self.dec(self.emb(t), h)[0])


def load_words(path, min_len=2, max_len=20):
    """Read one word per line from *path*, lower-cased and length-filtered.

    Returns ``(words, skipped)`` where *skipped* counts in-range words dropped
    because they contain a character outside VALID_CHARS (e.g. '-' or Latin
    letters), which would otherwise be encoded as the EOS/PAD token.
    """
    words = []
    skipped = 0
    # utf-8-sig strips a leading BOM, which str.strip() does not remove.
    with open(path, 'r', encoding='utf-8-sig') as f:
        for line in f:
            w = line.strip().lower()
            if not (min_len <= len(w) <= max_len):
                continue
            if all(c in VALID_CHARS for c in w):
                words.append(w)
            else:
                skipped += 1
    return words, skipped


def make_batch(words, rng, dev, batch_size=BATCH_SIZE):
    """Build one random training batch on device *dev*.

    Returns ``(src, tin, tout)`` int64 tensors of shape (batch_size, MAX_LEN+1):
    the (possibly corrupted) source, the decoder input (start token followed by
    the target shifted right), and the decoder target, all padded with ALPHA-1.
    """
    width = MAX_LEN + 1
    src_rows, tin_rows, tout_rows = [], [], []
    for _ in range(batch_size):
        tgt = words[rng.randint(0, len(words) - 1)]
        inp = corrupt(tgt, rng) if rng.random() < 0.7 else tgt
        s = enc(inp)
        t = enc(tgt)
        s += [ALPHA - 1] * (width - len(s))
        t += [ALPHA - 1] * (width - len(t))
        src_rows.append(s[:width])
        tin_rows.append([ALPHA - 1] + t[:MAX_LEN])
        tout_rows.append(t[:width])
    return (torch.tensor(src_rows, device=dev),
            torch.tensor(tin_rows, device=dev),
            torch.tensor(tout_rows, device=dev))


def train(model, words, dev, epochs=EPOCHS, batches_per_epoch=BATCHES_PER_EPOCH, seed=42):
    """Train *model* in place with Adam + cosine LR decay, logging each epoch.

    Accuracy counts per-character argmax matches over non-special target positions.
    """
    opt = torch.optim.Adam(model.parameters(), lr=LR)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    crit = nn.CrossEntropyLoss()
    rng = random.Random(seed)

    p(f"Training {epochs} epochs, batch={BATCH_SIZE}")
    t0 = time.time()
    for ep in range(epochs):
        model.train()
        total_loss = 0.0
        hits = 0
        counted = 0
        for _ in range(batches_per_epoch):
            src, tin, tout = make_batch(words, rng, dev)
            logits = model(src, tin)
            loss = crit(logits.view(-1, ALPHA), tout.view(-1))
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            total_loss += loss.item()
            with torch.no_grad():
                preds = logits.argmax(-1)
                mask = tout != ALPHA - 1
                hits += (preds == tout)[mask].sum().item()
                counted += mask.sum().item()
        sched.step()
        acc = hits / counted * 100 if counted else 0.0
        p(f" Epoch {ep}: loss={total_loss / batches_per_epoch:.3f} acc={acc:.1f}% "
          f"lr={sched.get_last_lr()[0]:.6f} [{time.time() - t0:.0f}s]")


def _write_f32(f, arr):
    """Write *arr* as native-endian float32 values in row-major order."""
    f.write(np.ascontiguousarray(arr, dtype=np.float32).tobytes())


def _write_gru(f, gru):
    """Write every parameter tensor of *gru* as an int32 length then float32 data."""
    # nn.GRU registers parameters layer-major: for each layer k,
    # weight_ih_lk, weight_hh_lk, bias_ih_lk, bias_hh_lk.
    for _, param in gru.named_parameters():
        data = param.detach().cpu().numpy().ravel()
        f.write(struct.pack('i', data.size))
        _write_f32(f, data)


def export(model, path=SAVE_PATH):
    """Serialise *model* weights to *path*.

    Layout (native byte order): 4 x int32 header (ALPHA, EMBED_DIM, HIDDEN,
    MAX_LEN); embedding table (ALPHA x EMBED_DIM float32); all encoder GRU
    tensors then all decoder GRU tensors (each an int32 element count followed
    by float32 data, layers in order l0, l1, ...); output weight
    (ALPHA x HIDDEN float32); output bias (ALPHA float32).

    NOTE(review): earlier versions exported only the layer-0 GRU tensors,
    which dropped half of a 2-layer model. The header does not record
    NUM_LAYERS, so the reader must read NUM_LAYERS * 4 tensors per GRU —
    confirm the consumer matches. With NUM_LAYERS == 1 the layout is
    identical to the old one.
    """
    with open(path, 'wb') as f:
        f.write(struct.pack('iiii', ALPHA, EMBED_DIM, HIDDEN, MAX_LEN))
        _write_f32(f, model.emb.weight.detach().cpu().numpy())
        _write_gru(f, model.enc)
        _write_gru(f, model.dec)
        _write_f32(f, model.out.weight.detach().cpu().numpy())
        _write_f32(f, model.out.bias.detach().cpu().numpy())


@torch.no_grad()
def spell_correct(model, word, dev):
    """Greedy-decode a correction of *word* (call with the model in eval mode)."""
    ids = enc(word)[:MAX_LEN + 1]
    ids += [ALPHA - 1] * (MAX_LEN + 1 - len(ids))
    src = torch.tensor([ids], device=dev)
    _, h = model.enc(model.emb(src))
    res = []
    prev = ALPHA - 1  # decoder start token
    for _ in range(MAX_LEN):
        o, h = model.dec(model.emb(torch.tensor([[prev]], device=dev)), h)
        pred = model.out(o.squeeze(1)).argmax(-1).item()
        if pred == ALPHA - 1:  # EOS
            break
        ch = ic(pred)
        if ch:
            res.append(ch)
        prev = pred
    return ''.join(res)


def main():
    """Load the dictionary, train, export the weights and print sample corrections."""
    p("Loading ALL words...")
    words, skipped = load_words(os.path.join(DATA_DIR, 'dict_ru.txt'))
    p(f"Loaded {len(words)} words (ALL)")
    if skipped:
        p(f"Skipped {skipped} words containing characters outside the model alphabet")
    if not words:
        raise SystemExit("No usable words loaded from dict_ru.txt")

    if torch.cuda.is_available():
        dev = torch.device('cuda')
        p(f"Device: {dev} ({torch.cuda.get_device_name(0)})")
    else:
        dev = torch.device('cpu')
        p(f"Device: {dev} (CUDA unavailable, training will be slow)")

    model = M().to(dev)
    params = sum(t.numel() for t in model.parameters())
    p(f"Model: hidden={HIDDEN}, embed={EMBED_DIM}, layers={NUM_LAYERS}, params={params:,}")

    train(model, words, dev)

    p("Exporting...")
    model.eval()
    export(model, SAVE_PATH)
    p(f"Saved: {os.path.getsize(SAVE_PATH) // 1024}KB")

    p("\n=== Test ===")
    for word in TEST_WORDS:
        p(f" {word} -> {spell_correct(model, word, dev)}")
    p("Done!")


if __name__ == "__main__":
    main()