"""Train a 1-layer GRU seq2seq spell corrector (compatible with C# inference) on ALL words.

Originally tuned for an RTX 4080; falls back to CPU when CUDA is unavailable.
The exported binary is intended for the C# GruSpellNet inference code, which
expects single-layer encoder/decoder GRUs.
"""
import os
import random
import struct
import sys
import time

import numpy as np
import torch
import torch.nn as nn

ALPHA = 35  # 32 letters а..я, ё, space, plus one id shared by EOS/PAD/SOS/unknown
HIDDEN = 256
EMBED_DIM = 64
MAX_LEN = 25
EPOCHS = 80
BATCH_SIZE = 1024
LR = 0.001
BATCHES_PER_EPOCH = 300
SEED = 42
EOS = ALPHA - 1  # also used as padding, decoder start token and "unknown character"
SEQ_LEN = MAX_LEN + 1  # MAX_LEN characters + the trailing EOS
DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'Data')
SAVE_PATH = os.path.join(DATA_DIR, 'seq2spell_gpu.bin')


def p(*a, **k):
    """print() that flushes stdout immediately so progress shows up in buffered logs."""
    print(*a, **k)
    sys.stdout.flush()


def ci(c):
    """Map a character to its token id; anything outside а-я, ё and space maps to EOS."""
    if 'а' <= c <= 'я':
        return ord(c) - ord('а')
    if c == 'ё':
        return 32
    if c == ' ':
        return 33
    return ALPHA - 1


def ic(i):
    """Inverse of ``ci``: token id -> character, or '' for EOS/out-of-range ids."""
    if 0 <= i <= 31:
        return chr(ord('а') + i)
    if i == 32:
        return 'ё'
    if i == 33:
        return ' '
    return ''


def corrupt(w, rng):
    """Return *w* with 1-3 random edits (adjacent swap, substitution, deletion, insertion).

    Words shorter than 3 characters are returned unchanged. Substituted/inserted
    letters are drawn from а..я only (never ё).
    """
    w = list(w)
    if len(w) < 3:
        return ''.join(w)
    for _ in range(rng.randint(1, 3)):
        op = rng.randint(0, 3)
        if op == 0 and len(w) > 1:
            i = rng.randint(0, len(w) - 2)
            w[i], w[i + 1] = w[i + 1], w[i]
        elif op == 1:
            w[rng.randint(0, len(w) - 1)] = chr(ord('а') + rng.randint(0, 31))
        elif op == 2 and len(w) > 2:
            del w[rng.randint(0, len(w) - 1)]
        elif op == 3:
            w.insert(rng.randint(0, len(w)), chr(ord('а') + rng.randint(0, 31)))
    return ''.join(w)


def enc(word):
    """Encode the first MAX_LEN characters of *word* as token ids followed by EOS."""
    ids = [ci(c) for c in word[:MAX_LEN]]
    ids.append(ALPHA - 1)
    return ids


class M(nn.Module):
    """Embedding -> GRU encoder -> GRU decoder (seeded with the encoder state) -> Linear."""

    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(ALPHA, EMBED_DIM)
        # SINGLE layer GRU — compatible with C# GruSpellNet
        self.enc = nn.GRU(EMBED_DIM, HIDDEN, num_layers=1, batch_first=True)
        self.dec = nn.GRU(EMBED_DIM, HIDDEN, num_layers=1, batch_first=True)
        self.out = nn.Linear(HIDDEN, ALPHA)

    def forward(self, s, t):
        """Return decoder logits for source ids *s* and teacher-forced decoder input *t*."""
        _, h = self.enc(self.emb(s))
        return self.out(self.dec(self.emb(t), h)[0])


def _is_supported(word):
    """True if every character of *word* has its own token id (none collapse onto EOS)."""
    return all(ci(c) != EOS for c in word)


def _load_words(path):
    """Read lowercase dictionary words of length 2..20 that are fully encodable.

    Returns ``(words, skipped)``. *skipped* counts length-valid words that were dropped
    because they contain characters (e.g. '-', Latin letters) which ``ci`` would map
    onto EOS. Keeping those words would teach the model to stop mid-word.
    """
    words, skipped = [], 0
    # utf-8-sig strips a leading BOM that would otherwise corrupt the first word.
    with open(path, 'r', encoding='utf-8-sig') as f:
        for line in f:
            w = line.strip().lower()
            if not 2 <= len(w) <= 20:
                continue
            if _is_supported(w):
                words.append(w)
            else:
                skipped += 1
    return words, skipped


def _pad(ids):
    """Right-pad (or truncate) a token-id list to SEQ_LEN using EOS."""
    return (ids + [EOS] * SEQ_LEN)[:SEQ_LEN]


def _make_batch(words, rng, dev):
    """Build one (src, decoder_input, decoder_target) batch of BATCH_SIZE examples.

    About 70% of sources are corrupted copies of the target word; the rest are the
    clean word. The decoder input is the padded target shifted right by one, with EOS
    used as the start token.
    """
    sb, ti, to_ = [], [], []
    for _ in range(BATCH_SIZE):
        tgt = words[rng.randint(0, len(words) - 1)]
        inp = corrupt(tgt, rng) if rng.random() < 0.7 else tgt
        t = _pad(enc(tgt))
        sb.append(_pad(enc(inp)))
        ti.append([EOS] + t[:MAX_LEN])
        to_.append(t)
    return (torch.tensor(sb, device=dev),
            torch.tensor(ti, device=dev),
            torch.tensor(to_, device=dev))


def _train(model, words, dev):
    """Train *model* in place with Adam + cosine LR, logging loss/accuracy every 5 epochs."""
    opt = torch.optim.Adam(model.parameters(), lr=LR)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)
    crit = nn.CrossEntropyLoss()
    rng = random.Random(SEED)
    p(f"Training {EPOCHS} epochs, batch={BATCH_SIZE}")
    t0 = time.time()
    for ep in range(EPOCHS):
        model.train()
        tl, cr, tot = 0.0, 0, 0
        for _ in range(BATCHES_PER_EPOCH):
            src, tin, tout = _make_batch(words, rng, dev)
            logits = model(src, tin)
            # NOTE: padding shares the EOS id, so the loss also covers the pad positions
            # after the word ends; the accuracy below counts real characters only.
            loss = crit(logits.view(-1, ALPHA), tout.view(-1))
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            tl += loss.item()
            mask = tout != EOS
            cr += (logits.argmax(-1) == tout)[mask].sum().item()
            tot += mask.sum().item()
        sched.step()
        acc = cr / max(tot, 1) * 100
        if ep % 5 == 0 or ep == EPOCHS - 1:
            p(f" Epoch {ep}: loss={tl/BATCHES_PER_EPOCH:.3f} acc={acc:.1f}% [{time.time()-t0:.0f}s]")


def _write_f32(f, arr):
    """Write *arr* in row-major order as native-endian float32 (same bytes as struct 'f')."""
    f.write(np.ascontiguousarray(arr, dtype=np.float32).tobytes())


def _export_weights(model, path):
    """Write the weights in the binary layout read by the C# side.

    Layout:
    - header: int32 ALPHA, EMBED_DIM, HIDDEN, MAX_LEN;
    - the embedding matrix, row-major;
    - each encoder and then each decoder GRU parameter, in registration order
      (weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0), each prefixed by its
      int32 element count;
    - the output weight matrix, row-major, then its bias.
    """
    with open(path, 'wb') as f:
        f.write(struct.pack('iiii', ALPHA, EMBED_DIM, HIDDEN, MAX_LEN))
        _write_f32(f, model.emb.weight.detach().cpu().numpy())
        # 1-layer: all GRU params have the _l0 suffix
        for gru in (model.enc, model.dec):
            for _, pa in gru.named_parameters():
                d = pa.detach().cpu().numpy().ravel()
                f.write(struct.pack('i', d.size))
                _write_f32(f, d)
        _write_f32(f, model.out.weight.detach().cpu().numpy())
        _write_f32(f, model.out.bias.detach().cpu().numpy())


def _greedy_decode(model, word, dev):
    """Greedily decode a correction for *word*, stopping at EOS or after MAX_LEN steps."""
    src = torch.tensor([_pad(enc(word))], device=dev)
    res = []
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        _, h = model.enc(model.emb(src))
        prev = EOS
        for _ in range(MAX_LEN):
            it = torch.tensor([[prev]], device=dev)
            o, h = model.dec(model.emb(it), h)
            pr = model.out(o.squeeze(1)).argmax(-1).item()
            if pr == EOS:
                break
            res.append(ic(pr))
            prev = pr
    return ''.join(res)


def main():
    """Load the dictionary, train, export weights for C#, and print sample corrections."""
    p("Loading ALL words...")
    words, skipped = _load_words(os.path.join(DATA_DIR, 'dict_ru.txt'))
    p(f"Loaded {len(words)} words")
    if skipped:
        p(f"Skipped {skipped} words containing unsupported characters")
    if not words:
        raise SystemExit("No usable words found in dict_ru.txt")

    torch.manual_seed(SEED)  # make weight init reproducible, like the data RNG
    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    name = torch.cuda.get_device_name(0) if dev.type == 'cuda' else 'CPU'
    p(f"Device: {dev} ({name})")
    model = M().to(dev)
    params = sum(pp.numel() for pp in model.parameters())
    p(f"Model: 1-layer GRU, hidden={HIDDEN}, embed={EMBED_DIM}, params={params:,}")

    _train(model, words, dev)

    p("Exporting 1-layer weights for C#...")
    model.eval()
    _export_weights(model, SAVE_PATH)
    p(f"Saved: {os.path.getsize(SAVE_PATH)//1024}KB")

    p("\n=== Test (greedy decode) ===")
    tests = ["привет", "прввет", "компуктер", "здраствуте", "пшоел", "кароче", "мсыли", "бсытро",
             "пчеатаю", "клвиатуре", "совршенно", "написатьб", "выгворится", "накпело", "хчоется",
             "тихналогия"]
    for t in tests:
        p(f" {t} -> {_greedy_decode(model, t, dev)}")
    p("Done!")


if __name__ == "__main__":
    main()