Skip to content

Commit 3d1572e

Browse files
committed
completed eda, data cleaning, initial training loop for teacher
1 parent 4e8f4a7 commit 3d1572e

22 files changed

Lines changed: 915 additions & 1 deletion

.flake8

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[flake8]
2+
max-line-length = 120
3+
ignore =
4+
E203, # spacing before colon (in conflict with black)
5+
W503 # line break before binary operator

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ data/raw/*
5151
data/processed/*
5252
data/models/*
5353
data/results/*
54+
data/clean/*
55+
!data/clean/.gitkeep
5456
!data/raw/.gitkeep
5557
!data/processed/.gitkeep
5658
!data/models/.gitkeep

configs/deberta_large_sst2.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
model_name: microsoft/deberta-v3-large
2+
dataset_path: data/sst2_dd
3+
train_split: train
4+
eval_split: val
5+
sanity_split: sanity
6+
max_len: 128
7+
per_device_train_batch_size: 8
8+
per_device_eval_batch_size: 32
9+
gradient_accumulation_steps: 4
10+
num_train_epochs: 3
11+
learning_rate: 2e-5
12+
warmup_ratio: 0.06
13+
fp16: true
14+
logging_steps: 50
15+
eval_steps: 200
16+
save_steps: 200
17+
output_dir: outputs/teacher
18+
report_to: wandb
19+
project_name: sst2_teacher

main.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import argparse
2+
import sys
3+
import os
4+
5+
# Ensure the src directory is in the Python path
6+
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
7+
8+
9+
try:
10+
from data.clean import run_cleaning_and_split
11+
except ImportError as e:
12+
print(f"Error importing modules: {e}")
13+
print("Please ensure your scripts are correctly placed in the 'src' directory and paths are correct.")
14+
sys.exit(1)
15+
16+
def main():
17+
parser = argparse.ArgumentParser(description="SentiSynth Project Main Entry Point")
18+
subparsers = parser.add_subparsers(dest='command', help='Available commands')
19+
20+
# --- Clean and Split Command ---
21+
parser_process = subparsers.add_parser('process_data', help='Clean raw data and create final train/val/sanity splits')
22+
parser_process.add_argument('--raw-path', default='./data/raw', help='Path to the raw dataset directory')
23+
parser_process.add_argument('--output-path', default='./data/sst2_dd', help='Path to save the final DatasetDict')
24+
parser_process.set_defaults(func=lambda args: run_cleaning_and_split(args.raw_path, args.output_path))
25+
26+
# --- Add other commands here as subparsers ---
27+
# Example: Download command
28+
# parser_download = subparsers.add_parser('download', help='Download the raw dataset')
29+
# parser_download.add_argument('--save-path', default='./data/raw', help='Path to save the raw dataset')
30+
# parser_download.set_defaults(func=lambda args: run_download(args.save_path)) # Assuming you create run_download
31+
32+
# Example: Train command
33+
# parser_train = subparsers.add_parser('train', help='Train a model')
34+
# parser_train.add_argument('--config', required=True, help='Path to the training configuration file')
35+
# ... other training args ...
36+
# parser_train.set_defaults(func=lambda args: run_training(args)) # Assuming you create run_training
37+
38+
# Parse arguments
39+
args = parser.parse_args()
40+
41+
# Execute the function associated with the chosen command
42+
if hasattr(args, 'func'):
43+
args.func(args)
44+
else:
45+
# If no command is given, print help
46+
parser.print_help()
47+
48+
if __name__ == "__main__":
49+
main()

notebooks/eda_sst_2.ipynb

Lines changed: 494 additions & 0 deletions
Large diffs are not rendered by default.

requirements.txt

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,17 @@ scikit-learn>=0.24.2
66
pandas>=1.3.0
77
matplotlib>=3.4.2
88
seaborn>=0.11.1
9-
tqdm>=4.61.2
9+
tqdm>=4.61.2
10+
datasets>=3.5.0
11+
12+
accelerate==1.6.0
13+
14+
# Configuration and Metrics
15+
PyYAML>=6.0.2
16+
scikit-learn>=1.6.1
17+
18+
# Logging
19+
wandb>=0.19.9
20+
Cmake
21+
sentencepiece
22+
protobuf

scripts/download_dataset.py

Whitespace-only changes.

scripts/train_model.py

Whitespace-only changes.

sentisynth/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)