windowcapture
исходный код / Tools/export_model.py

export_model.py

117 строк · 5,077 байт · модуль Tools
  1"""Re-export: retrain briefly (10 epochs), then export with numpy available."""
  2import torch, torch.nn as nn, struct, random, os, sys, time
  3
  4ALPHA=35; HIDDEN=192; EMBED_DIM=64; MAX_LEN=25; BATCH_SIZE=512
  5DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'Data')
  6SAVE_PATH = os.path.join(DATA_DIR, 'seq2spell_gpu.bin')
  7
  8def p(*a,**k): print(*a,**k); sys.stdout.flush()
  9def char_to_idx(c):
 10    if 'а' <= c <= 'я': return ord(c) - ord('а')
 11    if c == 'ё': return 32
 12    if c == ' ': return 33
 13    return ALPHA - 1
 14def idx_to_char(i):
 15    if 0 <= i <= 31: return chr(ord('а') + i)
 16    if i == 32: return 'ё'
 17    if i == 33: return ' '
 18    return ''
 19def corrupt(word, rng):
 20    w = list(word)
 21    if len(w) < 3: return word
 22    for _ in range(rng.randint(1,3)):
 23        op = rng.randint(0,3)
 24        if op==0 and len(w)>1: i=rng.randint(0,len(w)-2); w[i],w[i+1]=w[i+1],w[i]
 25        elif op==1: w[rng.randint(0,len(w)-1)]=chr(ord('а')+rng.randint(0,31))
 26        elif op==2 and len(w)>2: del w[rng.randint(0,len(w)-1)]
 27        elif op==3: w.insert(rng.randint(0,len(w)),chr(ord('а')+rng.randint(0,31)))
 28    return ''.join(w)
 29def encode(word):
 30    ids = [char_to_idx(c) for c in word[:MAX_LEN]]; ids.append(ALPHA-1); return ids
 31
 32class Model(nn.Module):
 33    def __init__(self):
 34        super().__init__()
 35        self.embed = nn.Embedding(ALPHA, EMBED_DIM)
 36        self.encoder = nn.GRU(EMBED_DIM, HIDDEN, num_layers=2, batch_first=True, dropout=0.1)
 37        self.decoder = nn.GRU(EMBED_DIM, HIDDEN, num_layers=2, batch_first=True, dropout=0.1)
 38        self.output = nn.Linear(HIDDEN, ALPHA)
 39    def forward(self, src, tgt):
 40        _, h = self.encoder(self.embed(src))
 41        return self.output(self.decoder(self.embed(tgt), h)[0])
 42
 43p("Loading dict...")
 44words = []
 45with open(os.path.join(DATA_DIR,'dict_ru.txt'),'r',encoding='utf-8') as f:
 46    for l in f:
 47        w=l.strip().lower()
 48        if 2<=len(w)<=20: words.append(w)
 49        if len(words)>=200000: break
 50p(f"{len(words)} words")
 51
 52device = torch.device('cuda')
 53model = Model().to(device)
 54optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
 55criterion = nn.CrossEntropyLoss()
 56rng = random.Random(42)
 57
 58# Quick train 10 epochs
 59p("Quick training 10 epochs on GPU...")
 60for epoch in range(10):
 61    model.train(); tl=0; cr=0; tot=0
 62    for _ in range(200):
 63        sb,ti,to_=[],[],[]
 64        for _ in range(BATCH_SIZE):
 65            tgt=words[rng.randint(0,len(words)-1)]
 66            inp=corrupt(tgt,rng) if rng.random()<0.7 else tgt
 67            s=encode(inp);t=encode(tgt)
 68            while len(s)<MAX_LEN+1:s.append(ALPHA-1)
 69            while len(t)<MAX_LEN+1:t.append(ALPHA-1)
 70            sb.append(s[:MAX_LEN+1]);ti.append([ALPHA-1]+t[:MAX_LEN]);to_.append(t[:MAX_LEN+1])
 71        src=torch.tensor(sb,device=device);tin=torch.tensor(ti,device=device);tout=torch.tensor(to_,device=device)
 72        logits=model(src,tin);loss=criterion(logits.view(-1,ALPHA),tout.view(-1))
 73        optimizer.zero_grad();loss.backward();torch.nn.utils.clip_grad_norm_(model.parameters(),1.0);optimizer.step()
 74        tl+=loss.item();preds=logits.argmax(-1);mask=tout!=ALPHA-1;cr+=(preds==tout)[mask].sum().item();tot+=mask.sum().item()
 75    p(f"  Epoch {epoch}: acc={cr/tot*100:.1f}%")
 76
 77p("Exporting...")
 78model.eval()
 79import numpy as np
 80with open(SAVE_PATH,'wb') as f:
 81    f.write(struct.pack('iiii',ALPHA,EMBED_DIM,HIDDEN,MAX_LEN))
 82    emb=model.embed.weight.detach().cpu().numpy()
 83    for i in range(ALPHA):
 84        for j in range(EMBED_DIM): f.write(struct.pack('f',float(emb[i][j])))
 85    for name,param in model.encoder.named_parameters():
 86        if '_l0' in name:
 87            data=param.detach().cpu().numpy().flatten()
 88            f.write(struct.pack('i',len(data)))
 89            for v in data: f.write(struct.pack('f',float(v)))
 90    for name,param in model.decoder.named_parameters():
 91        if '_l0' in name:
 92            data=param.detach().cpu().numpy().flatten()
 93            f.write(struct.pack('i',len(data)))
 94            for v in data: f.write(struct.pack('f',float(v)))
 95    w=model.output.weight.detach().cpu().numpy();b=model.output.bias.detach().cpu().numpy()
 96    for i in range(ALPHA):
 97        for j in range(HIDDEN): f.write(struct.pack('f',float(w[i][j])))
 98    for i in range(ALPHA): f.write(struct.pack('f',float(b[i])))
 99p(f"Saved: {os.path.getsize(SAVE_PATH)//1024}KB")
100
# Test: greedy-decode a few sample words and print the model's output.
p("\n=== Test ===")
tests=["привет","прввет","компуктер","здраствуте","пшоел","кароче","тихналогия","мсыли","бсытро","написатьб"]
model.eval()  # dropout must be off even if this section runs on its own
with torch.no_grad():  # inference only, so no autograd graph is needed
    for t in tests:
        # Encode once and pad to the fixed source length used in training.
        ids = encode(t)[:MAX_LEN + 1]
        ids = ids + [ALPHA - 1] * (MAX_LEN + 1 - len(ids))
        src = torch.tensor([ids], device=device)
        _, h = model.encoder(model.embed(src))

        result = []
        ic = ALPHA - 1  # decoder start index
        for _ in range(MAX_LEN):
            it = torch.tensor([[ic]], device=device)
            out, h = model.decoder(model.embed(it), h)
            pred = model.output(out.squeeze(1)).argmax(-1).item()
            if pred == ALPHA - 1:
                # The special index ends the output.
                break
            ch = idx_to_char(pred)
            if ch:
                result.append(ch)
            ic = pred  # feed the prediction back in as the next input
        p(f"  {t} -> {''.join(result)}")
p("Done!")