windowcapture
исходный код / Tools/train_full.py

train_full.py

128 строк · 5,150 байт · модуль Tools
  1"""Train GRU Seq2Seq on ALL 1.5M words, 50 epochs, RTX 4080."""
  2import torch, torch.nn as nn, struct, random, os, sys, time, numpy as np
  3
  4ALPHA=35; HIDDEN=256; EMBED_DIM=64; MAX_LEN=25
  5EPOCHS=50; BATCH_SIZE=1024; LR=0.001
  6
  7DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'Data')
  8SAVE_PATH = os.path.join(DATA_DIR, 'seq2spell_gpu.bin')
  9
 10def p(*a,**k): print(*a,**k); sys.stdout.flush()
 11
 12def ci(c):
 13    if 'а'<=c<='я': return ord(c)-ord('а')
 14    if c=='ё': return 32
 15    if c==' ': return 33
 16    return ALPHA-1
 17
 18def ic(i):
 19    if 0<=i<=31: return chr(ord('а')+i)
 20    if i==32: return 'ё'
 21    if i==33: return ' '
 22    return ''
 23
 24def corrupt(w, rng):
 25    w=list(w)
 26    if len(w)<3: return ''.join(w)
 27    for _ in range(rng.randint(1,3)):
 28        op=rng.randint(0,3)
 29        if op==0 and len(w)>1: i=rng.randint(0,len(w)-2);w[i],w[i+1]=w[i+1],w[i]
 30        elif op==1: w[rng.randint(0,len(w)-1)]=chr(ord('а')+rng.randint(0,31))
 31        elif op==2 and len(w)>2: del w[rng.randint(0,len(w)-1)]
 32        elif op==3: w.insert(rng.randint(0,len(w)),chr(ord('а')+rng.randint(0,31)))
 33    return ''.join(w)
 34
 35def enc(word):
 36    ids=[ci(c) for c in word[:MAX_LEN]]; ids.append(ALPHA-1); return ids
 37
 38class M(nn.Module):
 39    def __init__(self):
 40        super().__init__()
 41        self.emb=nn.Embedding(ALPHA,EMBED_DIM)
 42        self.enc=nn.GRU(EMBED_DIM,HIDDEN,num_layers=2,batch_first=True,dropout=0.15)
 43        self.dec=nn.GRU(EMBED_DIM,HIDDEN,num_layers=2,batch_first=True,dropout=0.15)
 44        self.out=nn.Linear(HIDDEN,ALPHA)
 45    def forward(self,s,t):
 46        _,h=self.enc(self.emb(s))
 47        return self.out(self.dec(self.emb(t),h)[0])
 48
 49p("Loading ALL words...")
 50words=[]
 51with open(os.path.join(DATA_DIR,'dict_ru.txt'),'r',encoding='utf-8') as f:
 52    for l in f:
 53        w=l.strip().lower()
 54        if 2<=len(w)<=20: words.append(w)
 55p(f"Loaded {len(words)} words (ALL)")
 56
 57dev=torch.device('cuda')
 58p(f"Device: {dev} ({torch.cuda.get_device_name(0)})")
 59
 60model=M().to(dev)
 61params=sum(p2.numel() for p2 in model.parameters())
 62p(f"Model: hidden={HIDDEN}, embed={EMBED_DIM}, layers=2, params={params:,}")
 63opt=torch.optim.Adam(model.parameters(),lr=LR)
 64sched=torch.optim.lr_scheduler.CosineAnnealingLR(opt,T_max=EPOCHS)
 65crit=nn.CrossEntropyLoss()
 66rng=random.Random(42)
 67
 68p(f"Training {EPOCHS} epochs, batch={BATCH_SIZE}")
 69t0=time.time()
 70
 71for ep in range(EPOCHS):
 72    model.train(); tl=0;cr=0;tot=0
 73    nb=300
 74    for _ in range(nb):
 75        sb,ti,to_=[],[],[]
 76        for _ in range(BATCH_SIZE):
 77            tgt=words[rng.randint(0,len(words)-1)]
 78            inp=corrupt(tgt,rng) if rng.random()<0.7 else tgt
 79            s=enc(inp);t=enc(tgt)
 80            while len(s)<MAX_LEN+1:s.append(ALPHA-1)
 81            while len(t)<MAX_LEN+1:t.append(ALPHA-1)
 82            sb.append(s[:MAX_LEN+1]);ti.append([ALPHA-1]+t[:MAX_LEN]);to_.append(t[:MAX_LEN+1])
 83        src=torch.tensor(sb,device=dev);tin=torch.tensor(ti,device=dev);tout=torch.tensor(to_,device=dev)
 84        logits=model(src,tin);loss=crit(logits.view(-1,ALPHA),tout.view(-1))
 85        opt.zero_grad();loss.backward();torch.nn.utils.clip_grad_norm_(model.parameters(),1.0);opt.step()
 86        tl+=loss.item();preds=logits.argmax(-1);mask=tout!=ALPHA-1
 87        cr+=(preds==tout)[mask].sum().item();tot+=mask.sum().item()
 88    sched.step()
 89    acc=cr/tot*100
 90    p(f"  Epoch {ep}: loss={tl/nb:.3f} acc={acc:.1f}% lr={sched.get_last_lr()[0]:.6f} [{time.time()-t0:.0f}s]")
 91
 92p("Exporting...")
 93model.eval()
 94with open(SAVE_PATH,'wb') as f:
 95    f.write(struct.pack('iiii',ALPHA,EMBED_DIM,HIDDEN,MAX_LEN))
 96    e=model.emb.weight.detach().cpu().numpy()
 97    for i in range(ALPHA):
 98        for j in range(EMBED_DIM):f.write(struct.pack('f',float(e[i][j])))
 99    for n,pa in model.enc.named_parameters():
100        if '_l0' in n:
101            d=pa.detach().cpu().numpy().flatten();f.write(struct.pack('i',len(d)))
102            for v in d:f.write(struct.pack('f',float(v)))
103    for n,pa in model.dec.named_parameters():
104        if '_l0' in n:
105            d=pa.detach().cpu().numpy().flatten();f.write(struct.pack('i',len(d)))
106            for v in d:f.write(struct.pack('f',float(v)))
107    w=model.out.weight.detach().cpu().numpy();b=model.out.bias.detach().cpu().numpy()
108    for i in range(ALPHA):
109        for j in range(HIDDEN):f.write(struct.pack('f',float(w[i][j])))
110    for i in range(ALPHA):f.write(struct.pack('f',float(b[i])))
111p(f"Saved: {os.path.getsize(SAVE_PATH)//1024}KB")
112
p("\n=== Test ===")
tests=["привет","прввет","компуктер","здраствуте","пшоел","кароче","тихналогия","мсыли","бсытро","написатьб","пчеатаю","клвиатуре","совршенно"]
# Inference only: make sure dropout is off and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    for t in tests:
        ids = enc(t)  # encode once; the result is reused for the padding length
        padded = ids[:MAX_LEN + 1] + [ALPHA - 1] * (MAX_LEN + 1 - len(ids))
        src = torch.tensor([padded], device=dev)
        _, h = model.enc(model.emb(src))
        res = []
        ic2 = ALPHA - 1  # decoder start id, same as the one used in training
        # Greedy decoding: feed the arg-max id back in until the terminator
        # id is produced or MAX_LEN characters have been emitted.
        for _ in range(MAX_LEN):
            it = torch.tensor([[ic2]], device=dev)
            o, h = model.dec(model.emb(it), h)
            pr = model.out(o.squeeze(1)).argmax(-1).item()
            if pr == ALPHA - 1:
                break
            ch = ic(pr)
            if ch:
                res.append(ch)
            ic2 = pr
        p(f"  {t} -> {''.join(res)}")
p("Done!")