-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathseq2seq_loss.py
More file actions
44 lines (37 loc) · 2.25 KB
/
seq2seq_loss.py
File metadata and controls
44 lines (37 loc) · 2.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# Checkpoint location. Not used in this section; presumably read by the
# training/saving code further down -- confirm.
save_path = 'checkpoints/dev'

# load_preprocess() yields a 3-tuple: the (source, target) id-encoded texts,
# the (source, target) vocab-to-id mappings, and a third element that is
# discarded here.
(source_int_text, target_int_text), (source_vocab_to_int, target_vocab_to_int), _ = load_preprocess()

# Longest sentence on the *target* side, as the name states. Measuring
# source_int_text here would report the longest source sentence instead.
# max() raises ValueError if the target text is empty.
max_target_sentence_length = max(len(sentence) for sentence in target_int_text)
# Everything below is created inside a dedicated graph rather than the
# default one.
train_graph = tf.Graph()
with train_graph.as_default():
    # Input and hyperparameter tensors come from helpers defined outside this
    # section (presumably placeholders -- confirm against their definitions).
    (input_data,
     targets,
     target_sequence_length,
     max_target_sequence_length) = enc_dec_model_inputs()
    lr, keep_prob = hyperparam_inputs()

    # The model receives the input reversed along its last axis.
    train_logits, inference_logits = seq2seq_model(
        tf.reverse(input_data, [-1]),
        targets,
        keep_prob,
        batch_size,
        target_sequence_length,
        max_target_sequence_length,
        len(source_vocab_to_int),
        len(target_vocab_to_int),
        encoding_embedding_size,
        decoding_embedding_size,
        rnn_size,
        num_layers,
        target_vocab_to_int,
    )

    # Give the two outputs stable op names ('logits' / 'predictions').
    # Note that inference_logits is rebound from the decoder output object
    # to its sample_id tensor.
    training_logits = tf.identity(train_logits.rnn_output, name='logits')
    inference_logits = tf.identity(inference_logits.sample_id, name='predictions')

    # https://www.tensorflow.org/api_docs/python/tf/sequence_mask
    # Float mask that is 1.0 for the first `target_sequence_length` positions
    # of each row and 0.0 for the rest, up to max_target_sequence_length.
    masks = tf.sequence_mask(
        target_sequence_length,
        maxlen=max_target_sequence_length,
        dtype=tf.float32,
        name='masks',
    )

    with tf.name_scope("optimization"):
        # Weighted softmax cross entropy; the mask is passed as the
        # per-position weights.
        cost = tf.contrib.seq2seq.sequence_loss(
            logits=training_logits,
            targets=targets,
            weights=masks,
        )

        # Adam, driven by the learning-rate tensor.
        optimizer = tf.train.AdamOptimizer(learning_rate=lr)

        # Clip every gradient element-wise into [-1, 1] before applying it.
        # Variables that received no gradient (None) are skipped.
        gradients = optimizer.compute_gradients(cost)
        capped_gradients = [
            (tf.clip_by_value(grad, -1.0, 1.0), var)
            for grad, var in gradients
            if grad is not None
        ]
        train_op = optimizer.apply_gradients(capped_gradients)