Skip to content

Commit 9d90125

Browse files
committed
add synthetic gc content concept & test
1 parent 8f90832 commit 9d90125

3 files changed

Lines changed: 106 additions & 17 deletions

File tree

test/test_cav_trainer.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def test_run_tpcav_random_control(self):
115115
genome_fasta=genome_fasta,
116116
num_motif_insertions=[4, 8],
117117
bed_seq_file="data/hg38_rmsk.sample.bed",
118+
synthetic_gc_concept_step=0.1,
118119
output_dir="data/test_run_tpcav_output/",
119120
)
120121

@@ -234,13 +235,37 @@ def test_motif_concepts_against_permute_control(self):
234235

235236
cav_trainer.train_concepts_pairs(concepts_pairs, 200, output_dir="data/cavs_permute/", num_processes=2)
236237

238+
def test_synthetic_gc_content_concept(self):
    """Check that each synthetic GC concept yields sequences near its nominal GC fraction.

    Builds synthetic GC concepts with a step of 0.1, takes the first batch
    from every concept's data iterator, and compares the observed G+C
    fraction of that batch with the value encoded in the concept name.
    """
    builder = ConceptBuilder(
        genome_fasta="data/hg38.analysisSet.fa",
        input_window_length=1024,
        bws=None,
        num_motifs=12,
        include_reverse_complement=True,
        min_samples=1000,
        batch_size=8,
    )

    builder.build_control()

    builder.add_synthetic_gc_content_concepts(0.1)

    # Guard against a vacuous pass: the loop below checks nothing if no
    # concepts were registered.
    assert builder.concepts, "no synthetic GC concepts were added"

    for c in builder.concepts:
        # The nominal GC fraction is the last "_"-separated token of the name.
        gc_content = float(c.name.split('_')[-1])
        seq, _ = next(iter(c.data_iter))
        total_c = sum(len(s) for s in seq)
        gc_c = sum(s.count('G') + s.count('C') for s in seq)
        observed = gc_c / total_c
        # Two-sided tolerance: without abs() a batch whose GC fraction is far
        # BELOW the target would still satisfy "difference < 0.05".
        assert abs(observed - gc_content) < 0.05, (
            f"{c.name}: observed GC fraction {observed:.3f} "
            f"deviates from target {gc_content:.2f}"
        )
237262

238263
def test_all(self):
239-
lp = LineProfiler()
240-
# Add installed-package functions you care about
241-
lp.add_function(utils.iterate_seq_df_chunk)
242-
lp.add_function(CavTrainer.train_concepts)
243-
lp.enable_by_count()
264+
#lp = LineProfiler()
265+
# # Add installed-package functions you care about
266+
#lp.add_function(utils.iterate_seq_df_chunk)
267+
#lp.add_function(CavTrainer.train_concepts)
268+
#lp.enable_by_count()
244269

245270
motif_path = Path("data") / "motif-clustering-v2.1beta_consensus_pwms.test.meme"
246271
self.assertTrue(motif_path.exists(), "Motif file is missing")
@@ -259,6 +284,8 @@ def test_all(self):
259284

260285
builder.add_meme_motif_concepts(str(motif_path))
261286

287+
builder.add_synthetic_gc_content_concepts(0.1)
288+
262289
builder.apply_transform(transform_fasta_to_one_hot_seq)
263290

264291
batch = next(iter(builder.all_concepts()[0].data_iter))
@@ -370,7 +397,7 @@ def forward_from_layer_1_embeddings(tm, avs_residual, avs_projected):
370397
f"Attributions do not match, max difference is {torch.abs(attributions - attributions_old).max()}",
371398
)
372399

373-
lp.disable_by_count()
400+
#lp.disable_by_count()
374401
#lp.print_stats()
375402

376403
# test save and restore states

tpcav/cavs.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,7 @@ def run_tpcav(
877877
motif_control_type="random",
878878
bed_seq_file: Optional[str] = None,
879879
bed_chrom_file: Optional[str] = None,
880+
synthetic_gc_concept_step: Optional[float] = None,
880881
layer_name: Optional[str]=None,
881882
layer=None,
882883
output_dir: str = "tpcav/",
@@ -941,8 +942,8 @@ def run_tpcav(
941942
motif_concept_builders[nm] = builder
942943

943944
## bed concepts (optional)
944-
if bed_seq_file is not None or bed_chrom_file is not None:
945-
bed_builder = ConceptBuilder(
945+
if bed_seq_file is not None or bed_chrom_file is not None or synthetic_gc_concept_step is not None:
946+
non_motif_concept_builder = ConceptBuilder(
946947
genome_fasta=genome_fasta,
947948
input_window_length=input_window_length,
948949
bws=bws,
@@ -953,23 +954,26 @@ def run_tpcav(
953954
rng_seed = seed,
954955
)
955956
# use random regions as control
956-
bed_builder.build_control()
957+
non_motif_concept_builder.build_control()
957958
if bed_seq_file is not None:
958959
# build concepts from fasta sequences in bed file
959-
bed_builder.add_bed_sequence_concepts(bed_seq_file)
960+
non_motif_concept_builder.add_bed_sequence_concepts(bed_seq_file)
960961
if bed_chrom_file is not None:
961962
# build concepts from chromatin tracks in bed file
962-
bed_builder.add_bed_chrom_concepts(bed_chrom_file)
963+
non_motif_concept_builder.add_bed_chrom_concepts(bed_chrom_file)
964+
if synthetic_gc_concept_step is not None:
965+
# build synthetic gc content concepts
966+
non_motif_concept_builder.add_synthetic_gc_content_concepts(synthetic_gc_concept_step)
963967
# apply transform to convert fasta sequences to one-hot encoded sequences
964-
bed_builder.apply_transform(input_transform_func)
968+
non_motif_concept_builder.apply_transform(input_transform_func)
965969
else:
966-
bed_builder = None
970+
non_motif_concept_builder = None
967971

968972
# create TPCAV model on top of the given model
969973
tpcav_model = TPCAV(model, layer_name=layer_name, layer=layer)
970974
# fit PCA on sampled all concept activations of the last builder (should have the most motifs)
971975
tpcav_model.fit_pca(
972-
concepts=motif_concept_builders[num_motif_insertions[-1]].concepts_for_pca() + bed_builder.concepts_for_pca() if bed_builder is not None else motif_concept_builders[num_motif_insertions[-1]].concepts_for_pca(),
976+
concepts=motif_concept_builders[num_motif_insertions[-1]].concepts_for_pca() + non_motif_concept_builder.concepts_for_pca() if non_motif_concept_builder is not None else motif_concept_builders[num_motif_insertions[-1]].concepts_for_pca(),
973977
num_samples_per_concept=num_samples_for_pca,
974978
num_pc=num_pc,
975979
)
@@ -993,13 +997,14 @@ def run_tpcav(
993997
if save_cav_trainer:
994998
torch.save(cav_trainer, str(output_path / f"cavs_{nm}_motifs/cav_trainer.pt"))
995999
motif_cav_trainers[nm] = cav_trainer
996-
if bed_builder is not None:
1000+
1001+
if non_motif_concept_builder is not None:
9971002
bed_cav_trainer = CavTrainer(tpcav_model, penalty="l2")
9981003
bed_cav_trainer.set_control(
999-
bed_builder.control_concepts[0], num_samples=num_samples_for_cav
1004+
non_motif_concept_builder.control_concepts[0], num_samples=num_samples_for_cav
10001005
)
10011006
bed_cav_trainer.train_concepts(
1002-
bed_builder.concepts,
1007+
non_motif_concept_builder.concepts,
10031008
num_samples_for_cav,
10041009
output_dir=str(output_path / f"cavs_bed_concepts/"),
10051010
num_processes=p,

tpcav/concepts.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,26 @@ def __iter__(self):
3333
yield inputs
3434

3535

36+
class _SyntheticGCSeqIterator:
37+
def __init__(self, seq_len: int, n: int, batch_size: int, gc: float, seed: int):
38+
self.seq_len = int(seq_len)
39+
self.n = int(n)
40+
self.batch_size = int(batch_size)
41+
self.gc = float(gc)
42+
self.seed = int(seed)
43+
44+
def __iter__(self):
45+
rng = np.random.RandomState(self.seed)
46+
bases = np.array(["A", "C", "G", "T"], dtype="<U1")
47+
p_at = (1.0 - self.gc) / 2.0
48+
p_gc = self.gc / 2.0
49+
p = [p_at, p_gc, p_gc, p_at]
50+
for start in range(0, self.n, self.batch_size):
51+
bs = min(self.batch_size, self.n - start)
52+
arr = rng.choice(4, size=(bs, self.seq_len), p=p)
53+
seqs = ["".join(bases[row]) for row in arr]
54+
yield seqs
55+
3656
def _construct_motif_concept_dataloader_from_control(
3757
control_seq_df: pd.DataFrame,
3858
genome_fasta: str,
@@ -145,6 +165,43 @@ def _control_chrom_dl(self):
145165
)
146166
return chrom_iter
147167

168+
def add_synthetic_gc_content_concepts(self, gc_content_step=0.1):
    """Add one synthetic-sequence concept per GC-content level.

    GC levels start at 0.0 and increase by ``gc_content_step`` up to (and,
    within floating tolerance, including) 1.0. For each level a concept named
    ``synthetic_gc_<level>`` plus ``self.concept_name_suffix`` is appended to
    ``self.concepts``. Its data iterator pairs randomly generated sequences
    of length ``self.input_window_length`` with ``self._control_chrom_dl()``
    through ``_PairedLoader``, batched by ``self.batch_size``.

    The step and the GC levels actually used are recorded in
    ``self.metadata`` under ``"synthetic_gc_content_step"`` and
    ``"synthetic_gc_content_values"``.

    Args:
        gc_content_step: Spacing between consecutive GC levels; must be in
            (0, 1].

    Returns:
        The list of concepts added by this call, in increasing GC order.

    Raises:
        ValueError: If ``gc_content_step`` is not in (0, 1] (NaN included).
    """
    step = float(gc_content_step)
    # Written as a chained comparison so NaN is rejected as well.
    if not 0 < step <= 1:
        raise ValueError("gc_content_step must be in (0, 1].")

    # Include 1.0 endpoint (within floating tolerance), then clamp so the
    # values stored in metadata are exactly the ones the concepts use.
    gc_values = np.clip(np.arange(0.0, 1.0 + 1e-9, step, dtype=float), 0.0, 1.0).tolist()

    # Names use two decimals by default; steps finer than 0.01 get more
    # digits so adjacent GC levels do not format to the same concept name.
    decimals = max(2, int(np.ceil(-np.log10(step))))

    added: List[Concept] = []
    for gc in gc_values:
        concept_name = f"synthetic_gc_{gc:.{decimals}f}" + self.concept_name_suffix
        # Offset the builder seed by the GC level so each concept draws from
        # its own reproducible random stream.
        seed = self.rng_seed + int(round(gc * 10000))
        seq_iter = _SyntheticGCSeqIterator(
            seq_len=self.input_window_length,
            n=self.min_samples,
            batch_size=self.batch_size,
            gc=gc,
            seed=seed,
        )
        concept = Concept(
            id=self._reserve_id(),
            name=concept_name,
            data_iter=_PairedLoader(seq_iter, self._control_chrom_dl()),
        )
        self.concepts.append(concept)
        added.append(concept)

    self.metadata["synthetic_gc_content_step"] = step
    self.metadata["synthetic_gc_content_values"] = gc_values
    return added
204+
148205
def add_custom_motif_concepts(
149206
self, motif_table: str, control_regions: Optional[pd.DataFrame] = None, build_permute_control=True
150207
) -> Union[List[Concept], List[Tuple[Concept]]]:

0 commit comments

Comments
 (0)