-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathqe.py
More file actions
142 lines (111 loc) · 4.96 KB
/
qe.py
File metadata and controls
142 lines (111 loc) · 4.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#-*- coding:utf8 -*-
'''
Created on Jun 16, 2017
@author: czm
'''
from RNN import data_iter
from nematus import nmt
from nematus import data_iterator
import theano
from theano import tensor
from nematus.util import load_config
from nematus.theano_util import load_params,init_theano_params
from RNN import rnn, rnn_trg, rnn_stack_pro
import numpy
def _write_predictions(path, values):
    """Write each value of *values* on its own line of *path*, overwriting it.

    Missing parent directories are created first, so a long prediction run
    does not fail only at the very end because the output directory (e.g.
    ``qe/``) does not exist yet.

    :param path: output file path.
    :param values: iterable of predictions; each is written as ``str(value)``.
    """
    import os  # local import: the module header does not import os

    out_dir = os.path.dirname(path)
    # os.makedirs has no exist_ok argument on Python 2, hence the explicit check.
    if out_dir and not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    with open(path, 'w') as fp:
        for value in values:
            # write(), not writelines(): writelines() expects an iterable of
            # strings and only "worked" on a single str by iterating its chars.
            fp.write(str(value) + '\n')


def _predict_hter(model_module, model, datasets, save_file):
    """Score a parallel test set with a trained QE model and save the scores.

    Shared implementation of qe_rnn, qe_rnn_stack_pro and qe_rnn_trg, which
    used to be verbatim copies differing only in the model module used.

    :param model_module: module exposing ``init_params(options)`` and
        ``build_model(tparams, options)`` (``rnn``, ``rnn_trg`` or
        ``rnn_stack_pro``).
    :param model: path of the saved model; its config is read with
        ``load_config`` and its weights with ``load_params``.
    :param datasets: two-element sequence ``(source_file, target_file)``.
    :param save_file: output file name, written under ``qe/`` with one
        element of the flattened prediction per line.
    :returns: the list of predictions, in the order they were written.
    """
    options = load_config(model)
    params = load_params(model, model_module.init_params(options))
    tparams = init_theano_params(params)

    # Only the symbolic inputs and y_pred are needed for inference; the other
    # outputs of build_model are unused here.
    # NOTE(review): if use_noise switches dropout on, confirm it is off
    # (0) at this point -- build_model's initial value is not visible here.
    (_trng, _use_noise, x, x_mask, y, y_mask,
     _hter, y_pred, _cost) = model_module.build_model(tparams, options)
    f_pred = theano.function([x, x_mask, y, y_mask], y_pred, profile=False)

    test = data_iterator.TextIterator(
        datasets[0], datasets[1],
        options['dictionaries'][0], options['dictionaries'][1],
        n_words_source=options['n_words_src'],
        n_words_target=options['n_words_tgt'],
        batch_size=options['valid_batch_size'],
        maxlen=options['maxlen'],
        # presumably preserves file order so output lines follow input
        # lines -- confirm in nematus.data_iterator
        sort_by_length=False)

    # NOTE(review): both TextIterator and prepare_data receive maxlen; if
    # either drops over-long pairs, the output file will have fewer lines
    # than the input and no longer align with it -- verify.
    predictions = []
    n_samples = 0
    for src_batch, trg_batch in test:
        # Distinct names so the symbolic x/y used to compile f_pred are not
        # shadowed by the numeric batches.
        xb, xb_mask, yb, yb_mask = nmt.prepare_data(
            src_batch, trg_batch, maxlen=options['maxlen'],
            n_words_src=options['n_words_src'],
            n_words=options['n_words_tgt'])
        predictions.extend(f_pred(xb, xb_mask, yb, yb_mask).flatten())
        n_samples += xb.shape[1]  # axis 1 is treated as the batch axis
        # Same text as the old Python 2 statement `print 'processed:', n, 'samples'`.
        print('processed: %s samples' % n_samples)

    _write_predictions('qe/' + save_file, predictions)
    return predictions


def qe_rnn(model='RNN_model/wmt17.en-de.npz',
           datasets=('test/test16.bpe.en',
                     'test/test16.bpe.de'),
           save_file='test16.hter.pred'):
    """Predict scores with the ``rnn`` model and write them to ``qe/<save_file>``.

    :param model: saved model path.
    :param datasets: ``(source_file, target_file)`` pair to score.
    :param save_file: output file name under ``qe/``.
    """
    _predict_hter(rnn, model, datasets, save_file)


def qe_rnn_stack_pro(model='stack_model/stack.en-de.npz',
                     datasets=('test/test16.bpe.en',
                               'test/test16.bpe.de'),
                     save_file='test16.hter.pred'):
    """Predict scores with the ``rnn_stack_pro`` model; write to ``qe/<save_file>``.

    :param model: saved model path.
    :param datasets: ``(source_file, target_file)`` pair to score.
    :param save_file: output file name under ``qe/``.
    """
    _predict_hter(rnn_stack_pro, model, datasets, save_file)


def qe_rnn_trg(model='RNN_trg_model/wmt17.de-en.npz',
               datasets=('test/test16.bpe.src',
                         'test/test16.bpe.mt'),
               save_file='test16.hter.pred'):
    """Predict scores with the ``rnn_trg`` model and write them to ``qe/<save_file>``.

    :param model: saved model path.
    :param datasets: ``(source_file, target_file)`` pair to score.
    :param save_file: output file name under ``qe/``.
    """
    _predict_hter(rnn_trg, model, datasets, save_file)
def _main():
    """Script entry point.

    Deliberately does nothing: this module is meant to be imported, with one
    of qe_rnn, qe_rnn_stack_pro or qe_rnn_trg called on the desired paths.
    """


if __name__ == '__main__':
    _main()