windowcapture
исходный код / Tools/train_aggressive.py

train_aggressive.py

108 строк · 5,440 байт · модуль Tools
  1"""Aggressive GRU training: hidden=384, batch=2048, 100 epochs, maximize GPU usage."""
  2import torch,torch.nn as nn,struct,random,os,sys,time,numpy as np
  3ALPHA=35;HIDDEN=384;EMBED_DIM=96;MAX_LEN=25;EPOCHS=100;BATCH_SIZE=2048;LR=0.002
  4DATA_DIR=os.path.join(os.path.dirname(os.path.abspath(__file__)),'..','Data')
  5SAVE_PATH=os.path.join(DATA_DIR,'seq2spell_gpu.bin')
  6def p(*a,**k):print(*a,**k);sys.stdout.flush()
  7def ci(c):
  8    if 'а'<=c<='я':return ord(c)-ord('а')
  9    if c=='ё':return 32
 10    if c==' ':return 33
 11    return ALPHA-1
 12def ic(i):
 13    if 0<=i<=31:return chr(ord('а')+i)
 14    if i==32:return 'ё'
 15    if i==33:return ' '
 16    return ''
 17def corrupt(w,rng):
 18    w=list(w)
 19    if len(w)<3:return ''.join(w)
 20    for _ in range(rng.randint(1,4)): # up to 4 corruptions!
 21        op=rng.randint(0,3)
 22        if op==0 and len(w)>1:i=rng.randint(0,len(w)-2);w[i],w[i+1]=w[i+1],w[i]
 23        elif op==1:w[rng.randint(0,len(w)-1)]=chr(ord('а')+rng.randint(0,31))
 24        elif op==2 and len(w)>2:del w[rng.randint(0,len(w)-1)]
 25        elif op==3:w.insert(rng.randint(0,len(w)),chr(ord('а')+rng.randint(0,31)))
 26    return ''.join(w)
 27def enc(word):
 28    ids=[ci(c) for c in word[:MAX_LEN]];ids.append(ALPHA-1);return ids
 29class M(nn.Module):
 30    def __init__(self):
 31        super().__init__()
 32        self.emb=nn.Embedding(ALPHA,EMBED_DIM)
 33        self.enc=nn.GRU(EMBED_DIM,HIDDEN,num_layers=3,batch_first=True,dropout=0.2)
 34        self.dec=nn.GRU(EMBED_DIM,HIDDEN,num_layers=3,batch_first=True,dropout=0.2)
 35        self.out=nn.Linear(HIDDEN,ALPHA)
 36    def forward(self,s,t):
 37        _,h=self.enc(self.emb(s));return self.out(self.dec(self.emb(t),h)[0])
 38p("Loading ALL words...")
 39words=[]
 40with open(os.path.join(DATA_DIR,'dict_ru.txt'),'r',encoding='utf-8') as f:
 41    for l in f:
 42        w=l.strip().lower()
 43        if 2<=len(w)<=20:words.append(w)
 44p(f"Loaded {len(words)} words")
 45dev=torch.device('cuda')
 46p(f"Device: {dev} ({torch.cuda.get_device_name(0)})")
 47p(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory/1024**3:.1f}GB")
 48model=M().to(dev)
 49params=sum(pp.numel() for pp in model.parameters())
 50p(f"Model: hidden={HIDDEN}, embed={EMBED_DIM}, layers=3, params={params:,}")
 51opt=torch.optim.AdamW(model.parameters(),lr=LR,weight_decay=0.01)
 52sched=torch.optim.lr_scheduler.OneCycleLR(opt,max_lr=LR,total_steps=EPOCHS*400,pct_start=0.1)
 53crit=nn.CrossEntropyLoss()
 54rng=random.Random(42)
 55p(f"Training {EPOCHS} epochs, batch={BATCH_SIZE}, OneCycleLR")
 56t0=time.time();best_acc=0
 57for ep in range(EPOCHS):
 58    model.train();tl=0;cr=0;tot=0;nb=400
 59    for _ in range(nb):
 60        sb,ti,to_=[],[],[]
 61        for _ in range(BATCH_SIZE):
 62            tgt=words[rng.randint(0,len(words)-1)]
 63            inp=corrupt(tgt,rng) if rng.random()<0.75 else tgt
 64            s=enc(inp);t=enc(tgt)
 65            while len(s)<MAX_LEN+1:s.append(ALPHA-1)
 66            while len(t)<MAX_LEN+1:t.append(ALPHA-1)
 67            sb.append(s[:MAX_LEN+1]);ti.append([ALPHA-1]+t[:MAX_LEN]);to_.append(t[:MAX_LEN+1])
 68        src=torch.tensor(sb,device=dev);tin=torch.tensor(ti,device=dev);tout=torch.tensor(to_,device=dev)
 69        logits=model(src,tin);loss=crit(logits.view(-1,ALPHA),tout.view(-1))
 70        opt.zero_grad();loss.backward();torch.nn.utils.clip_grad_norm_(model.parameters(),1.0);opt.step();sched.step()
 71        tl+=loss.item();preds=logits.argmax(-1);mask=tout!=ALPHA-1
 72        cr+=(preds==tout)[mask].sum().item();tot+=mask.sum().item()
 73    acc=cr/tot*100
 74    if acc>best_acc:best_acc=acc
 75    if ep%5==0 or ep==EPOCHS-1:
 76        mem=torch.cuda.max_memory_allocated()/1024**3
 77        p(f"  Epoch {ep}: loss={tl/nb:.3f} acc={acc:.1f}% best={best_acc:.1f}% mem={mem:.1f}GB [{time.time()-t0:.0f}s]")
 78p(f"\nBest accuracy: {best_acc:.1f}%")
 79p("Exporting...")
 80model.eval()
 81with open(SAVE_PATH,'wb') as f:
 82    f.write(struct.pack('iiii',ALPHA,EMBED_DIM,HIDDEN,MAX_LEN))
 83    e=model.emb.weight.detach().cpu().numpy()
 84    for i in range(ALPHA):
 85        for j in range(EMBED_DIM):f.write(struct.pack('f',float(e[i][j])))
 86    for n,pa in model.enc.named_parameters():
 87        if '_l0' in n:d=pa.detach().cpu().numpy().flatten();f.write(struct.pack('i',len(d)));[f.write(struct.pack('f',float(v))) for v in d]
 88    for n,pa in model.dec.named_parameters():
 89        if '_l0' in n:d=pa.detach().cpu().numpy().flatten();f.write(struct.pack('i',len(d)));[f.write(struct.pack('f',float(v))) for v in d]
 90    w=model.out.weight.detach().cpu().numpy();b=model.out.bias.detach().cpu().numpy()
 91    for i in range(ALPHA):
 92        for j in range(HIDDEN):f.write(struct.pack('f',float(w[i][j])))
 93    for i in range(ALPHA):f.write(struct.pack('f',float(b[i])))
 94p(f"Saved: {os.path.getsize(SAVE_PATH)//1024}KB")
 95p("\n=== Test ===")
 96tests=["привет","прввет","компуктер","здраствуте","пшоел","кароче","тихналогия","мсыли","бсытро","пчеатаю","клвиатуре","совршенно","написатьб","выгворится","накпело","хчоется"]
 97for t in tests:
 98    src=torch.tensor([enc(t)[:MAX_LEN+1]+[ALPHA-1]*(MAX_LEN+1-len(enc(t)))],device=dev)
 99    _,h=model.enc(model.emb(src));res=[];ic2=ALPHA-1
100    for _ in range(MAX_LEN):
101        it=torch.tensor([[ic2]],device=dev);o,h=model.dec(model.emb(it),h)
102        pr=model.out(o.squeeze(1)).argmax(-1).item()
103        if pr==ALPHA-1:break
104        ch=ic(pr)
105        if ch:res.append(ch)
106        ic2=pr
107    p(f"  {t} -> {''.join(res)}")
108p("Done!")