Skip to content

Commit 2c560bb

Browse files
Porting latent embeddings to tf2
1 parent 515639b commit 2c560bb

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

joint_embeddings/get_latent_embeddings.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@ def get_embeddings_low_mem(model, seq_input, chrom_input):
2020
# iterate in batches for processing large datasets.
2121
for batch_start_idx in range(0, len(seq_input), 500):
2222
batch_end_idx = min(batch_start_idx + 500, len(seq_input))
23-
23+
current_batch_seq = seq_input[batch_start_idx:batch_end_idx]
24+
current_batch_chrom = chrom_input[batch_start_idx:batch_end_idx]
25+
print(current_batch_chrom)
2426
with eager_learning_phase_scope(value=0):
25-
sn_activations = np.array(f([seq_input[batch_start_idx:batch_end_idx],
26-
chrom_input[batch_start_idx, batch_end_idx]]))
27+
sn_activations = np.array(f([current_batch_seq,
28+
current_batch_chrom]))
2729
activations_rs = np.reshape(sn_activations, (sn_activations.shape[1], 2))
2830
activations_rs = activations_rs.astype(np.float64)
2931
embedding_list_by_batch.append(activations_rs)
@@ -32,4 +34,5 @@ def get_embeddings_low_mem(model, seq_input, chrom_input):
3234
w, b = model.layers[-1].get_weights()
3335
w = np.reshape(w, (2,))
3436
weighted_embeddings = activations * w
35-
return weighted_embeddings
37+
return weighted_embeddings
38+

0 commit comments

Comments
 (0)