-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexplore_pytorch.py
More file actions
137 lines (100 loc) · 4.59 KB
/
explore_pytorch.py
File metadata and controls
137 lines (100 loc) · 4.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
Example demonstrating how to use user-defined augmentation classes.
This shows how to define augmentation classes with @local_aug and @global_aug decorators.
"""
from torch.utils.data import DataLoader
from pyspark.sql import SparkSession
from dataaug_platform import Augmentation, local_aug, global_aug, Pipeline, SparkIterableDataset
# ============================================================================
# Example 1: User-defined local augmentation class
# ============================================================================
class AddOffsetAugmentation(Augmentation):
    """Add a constant offset to every ndarray field of each trajectory.

    Args:
        offset: Value added element-wise to every ``np.ndarray`` entry of a
            trajectory dict (default: 1.0). Non-array fields pass through
            unchanged.
    """

    def __init__(self, offset=1.0):
        # Fix: initialize the Augmentation base class so framework state
        # (e.g. the times/keep_original bookkeeping the sibling
        # AverageTrajectoriesAugmentation sets via super().__init__) exists
        # on this instance too. The original skipped this call.
        super().__init__()
        self.offset = offset

    @local_aug
    def apply(self, traj):
        """Process one trajectory at a time.

        Args:
            traj: dict of field name -> value; ndarray values get the offset
                added, all other values are carried over as-is.

        Returns:
            A shallow copy of ``traj`` with the offset applied to every
            ndarray field.
        """
        # Imported inside the method so the function is self-contained when
        # serialized to Spark workers.
        import numpy as np
        modified = traj.copy()
        for key, value in modified.items():
            if isinstance(value, np.ndarray):
                modified[key] = value + self.offset
        return modified
# ============================================================================
# Example 2: User-defined global augmentation class
# ============================================================================
class AverageTrajectoriesAugmentation(Augmentation):
    """Produce one new trajectory that is the mean of all input trajectories."""

    def __init__(self, times=1, keep_original=True):
        """
        Initialize the augmentation.

        Args:
            times: How many times this augmentation is run (default: 1).
                Each run sees the whole dataset and emits new trajectories;
                multiple runs are parallelized by Spark.
            keep_original: If True (default), originals are kept alongside
                the augmented output; if False, only augmented trajectories
                are emitted.
        """
        super().__init__(times=times, keep_original=keep_original)

    @global_aug
    def apply(self, trajs):
        """Process all trajectories together.

        For each field of the first trajectory, ndarray values are averaged
        element-wise across every trajectory carrying that field; other
        values fall back to the first occurrence.
        """
        import numpy as np

        if not trajs:
            return []

        def merge(field):
            # Gather this field from every trajectory that defines it.
            collected = [t[field] for t in trajs if field in t]
            if collected and isinstance(collected[0], np.ndarray):
                return np.mean(collected, axis=0)
            return collected[0] if collected else None

        # Field set is taken from the first trajectory, as in a schema.
        averaged = {field: merge(field) for field in trajs[0]}
        return [averaged]
# ============================================================================
# Example: Using the pipeline with class-based augmentations
# ============================================================================
def example():
    """Demonstrates the class-based augmentation style.

    Builds a Spark-backed pipeline with one local and one global
    augmentation, then consumes it both as a finite and as an infinite
    PyTorch iterable dataset.
    """
    spark = SparkSession.builder.appName("AugmentationExample").getOrCreate()

    pipeline = Pipeline(spark)
    # Local augmentation: shift every ndarray field by a constant.
    pipeline.add(AddOffsetAugmentation(offset=2.0))
    # Global augmentation, executed 3 times in parallel via Spark.
    # With keep_original=True (the default): 5 original + 3 averaged = 8
    # trajectories in the output.
    pipeline.add(AverageTrajectoriesAugmentation(times=3))
    # With keep_original=False the output would contain only the augmented
    # trajectories:
    # pipeline.add(AverageTrajectoriesAugmentation(times=3, keep_original=False))

    # Toy dataset: five trajectory dicts.
    pipeline.set_data([
        {"x": [1, 2, 3], "y": [4, 5, 6]},
        {"x": [2, 3, 4], "y": [5, 6, 7]},
        {"x": [3, 4, 5], "y": [6, 7, 8]},
        {"x": [7, 8, 9], "y": [10, 11, 12]},
        {"x": [11, 12, 13], "y": [14, 15, 16]},
    ])

    print("===== Finite Dataset =====")
    finite_loader = DataLoader(
        SparkIterableDataset(pipeline), batch_size=6, num_workers=1
    )
    for i, data in enumerate(finite_loader):
        print(f'[{i}] {data=}')

    print("===== Infinite Dataset =====")
    num_batches = 4
    print(f"Iteration count {num_batches}")
    infinite_loader = DataLoader(
        SparkIterableDataset(pipeline, infinite=True), batch_size=6, num_workers=1
    )
    # The infinite dataset never exhausts; stop after num_batches batches.
    batch_count = 0
    for i, data in enumerate(infinite_loader):
        if batch_count == num_batches:
            print("Iteration count reached")
            break
        print(f'[{i}] {data=}')
        batch_count += 1

    spark.stop()
if __name__ == "__main__":
    # Frame the banner and run the demo.
    banner = "=" * 60
    print(banner)
    print("Example: Class-based augmentation style")
    print(banner)
    example()