From 5be00c8b15a41c520481acc9a67f86964124c1af Mon Sep 17 00:00:00 2001 From: mcclain Date: Fri, 7 Nov 2025 14:06:48 +0000 Subject: [PATCH] added sweep info and better logging --- docker-compose.yaml | 8 +- src/config.py | 10 + src/eval/eval.py | 365 ++++++++++++++++++ src/eval/eval_config.py | 33 ++ src/rewards/bioinformatics/reward_config.py | 4 +- src/runners/grpo.py | 125 ++++-- src/runners/grpo_sweep.py | 50 ++- src/utils/training_utils.py | 243 ++++++++++++ .../sweep_config_training_with_eval.yaml | 144 +++++++ 9 files changed, 946 insertions(+), 36 deletions(-) create mode 100644 src/eval/eval.py create mode 100644 src/eval/eval_config.py create mode 100644 src/utils/training_utils.py create mode 100644 sweeps/configs/sweep_config_training_with_eval.yaml diff --git a/docker-compose.yaml b/docker-compose.yaml index be91ca5..d8f1338 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -41,8 +41,8 @@ services: " environment: # W&B settings for this job - - WANDB_ENTITY=mcclain - - WANDB_PROJECT=plasmidrl-trl-grpo + - WANDB_ENTITY=ucl-cssb + - WANDB_PROJECT=PlasmidRL - WANDB_TAGS=["plasmid","rl","trl","grpo"] - WANDB_NOTES=TRL GRPO training on plasmid design - WANDB_DIR=/tmp/wandb @@ -76,8 +76,8 @@ services: uv run wandb agent ${SWEEP_ID} " environment: - - WANDB_ENTITY=mcclain - - WANDB_PROJECT=plasmidrl-grpo-sweeps + - WANDB_ENTITY=ucl-cssb + - WANDB_PROJECT=PlasmidRL - WANDB_DIR=/tmp/wandb - SWEEP_ID=${SWEEP_ID} - PYTHONPATH=/mcclain diff --git a/src/config.py b/src/config.py index 872d0b2..1b1ad0d 100644 --- a/src/config.py +++ b/src/config.py @@ -47,6 +47,16 @@ class Config(BaseSettings): region_name: str = "us-east-1" runs_path: str = "runs/" infered_path: str = "infered/" + checkpoints_path: str = "checkpoints/" # S3 prefix for checkpoint storage + + # Production GRPO hyperparameters (from sweep optimization) + grpo_learning_rate: float = 0.00001906419115928539 + grpo_per_device_train_batch_size: int = 16 + grpo_num_generations: int = 4 + 
grpo_temperature: float = 1.2292317925218237 + grpo_top_p: float = 0.9086524230707756 + grpo_beta: float = 0.00088482365318492 + grpo_epsilon: float = 0.2649093053949679 model_config = { "env_file": ".env", diff --git a/src/eval/eval.py b/src/eval/eval.py new file mode 100644 index 0000000..03639b3 --- /dev/null +++ b/src/eval/eval.py @@ -0,0 +1,365 @@ +from vllm import LLM, SamplingParams +from typing import Optional, List, Dict, Any +from src.eval.eval_config import EvalConfig +from src.utils.training_utils import EvalRunner +from src.config import Config +import pandas as pd +import os +import plasmidkit as pk +import re + + +class _Feat: + """Simple feature container for annotation merging (same as Scorer).""" + def __init__(self, type: str, id: str | None, start: int, end: int, strand: str | None, evidence: Any = None): + self.type = type + self.id = id + self.start = int(start) + self.end = int(end) + self.strand = strand or "+" + self.evidence = evidence or {} + + +class SequenceAnalyzer: + """ + Analyzes plasmid sequences and extracts detailed annotation information. + + Similar to Scorer but returns detailed information instead of scores. + Uses plasmidkit for annotation and extracts counts and IDs for each feature type. 
+ """ + + def __init__(self, eval_config: EvalConfig): + self.eval_config = eval_config + + def annotate(self, sequence: str) -> List[Any]: + """Annotate sequence with plasmidkit and merge overlapping features.""" + assert sequence, "sequence cannot be empty" + raw = pk.annotate(sequence, is_sequence=True) + return self._preprocess_annotations(raw) + + @staticmethod + def _overlap_len(a: Any, b: Any) -> int: + """Calculate overlap length between two features.""" + s1, e1 = int(a.start), int(a.end) + s2, e2 = int(b.start), int(b.end) + lo = max(min(s1, e1), min(s2, e2)) + hi = min(max(s1, e1), max(s2, e2)) + return max(0, hi - lo) + + def _to_feat(self, x: Any) -> _Feat: + """Convert annotation object to internal _Feat representation.""" + return _Feat( + type=x.type.lower() if x.type else "", + id=x.id if hasattr(x, "id") else None, + start=int(x.start), + end=int(x.end), + strand=x.strand if hasattr(x, "strand") else "+", + evidence=x.evidence if hasattr(x, "evidence") else {}, + ) + + def _merge_group(self, feats: List[Any], threshold: float, *, respect_strand: bool) -> List[_Feat]: + """Merge overlapping features of the same type based on overlap threshold.""" + if not feats: + return [] + items = [self._to_feat(f) for f in feats] + items.sort(key=lambda f: (f.strand, f.start, f.end)) + merged: List[_Feat] = [] + cur = items[0] + for nxt in items[1:]: + ovl = self._overlap_len(cur, nxt) + cur_len = max(0, cur.end - cur.start) + nxt_len = max(0, nxt.end - nxt.start) + min_len = max(1, min(cur_len, nxt_len)) + strands_compatible = (cur.strand == nxt.strand) or (not respect_strand) + if ovl / float(min_len) >= threshold and strands_compatible: + cur.start = min(cur.start, nxt.start) + cur.end = max(cur.end, nxt.end) + cur.id = f"{cur.id}|{nxt.id}" if cur.id or nxt.id else None + else: + merged.append(cur) + cur = nxt + merged.append(cur) + return merged + + def _preprocess_annotations(self, annotations: Any) -> List[Any]: + """ + Merge overlapping annotations and 
filter out CDS overlapping with other feature types. + + Same logic as Scorer._preprocess_annotations. + """ + feats = list(annotations) + thr = float(self.eval_config.overlap_merge_threshold) + type_key = lambda x: x.type.lower() if x.type else "" + + # Collect groups by type + groups: Dict[str, List[Any]] = {} + for f in feats: + groups.setdefault(type_key(f), []).append(f) + + # Merge per group for relevant types + merged_groups: Dict[str, List[_Feat]] = {} + for t in ("rep_origin", "ori", "origin_of_replication", "promoter", "terminator", "marker", "cds"): + if t in groups: + respect = t not in ("rep_origin", "ori", "origin_of_replication", "marker") + merged_groups[t] = self._merge_group(groups[t], thr, respect_strand=respect) + + # Suppress CDS if overlaps any non-CDS + non_cds: List[_Feat] = [] + for t in ("rep_origin", "ori", "origin_of_replication", "promoter", "terminator", "marker"): + non_cds.extend(merged_groups.get(t, [])) + + filtered_cds: List[_Feat] = [] + for c in merged_groups.get("cds", []): + if any(self._overlap_len(c, o) > 0 for o in non_cds): + continue + filtered_cds.append(c) + merged_groups["cds"] = filtered_cds + + # Rebuild final list + final: List[Any] = [] + merged_types = set(merged_groups.keys()) + for t, items in merged_groups.items(): + final.extend(items) + for f in feats: + t = type_key(f) + if t not in merged_types: + final.append(f) + return final + + def analyze(self, sequence: str) -> Dict[str, Any]: + """ + Analyze a sequence and extract detailed annotation information. 
+ + Args: + sequence: DNA sequence to analyze + + Returns: + Dictionary with counts and IDs for each feature type + """ + annotations = self.annotate(sequence) + feats = list(annotations) + type_key = lambda x: x.type.lower() if x.type else "" + + # Extract features by type (no filtering - report all IDs found) + oris = [x for x in feats if type_key(x) in ("rep_origin", "ori", "origin_of_replication")] + promoters = [x for x in feats if type_key(x) == "promoter"] + terminators = [x for x in feats if type_key(x) == "terminator"] + markers = [x for x in feats if type_key(x) == "marker"] + cdss = [x for x in feats if type_key(x) == "cds"] + + # Extract IDs (handle merged IDs separated by |) + def extract_ids(features: List[Any]) -> List[str]: + ids = [] + for f in features: + if hasattr(f, "id") and f.id: + # Split merged IDs (separated by |) + ids.extend([id.strip() for id in str(f.id).split("|") if id.strip()]) + return ids + + return { + "ori_count": len(oris), + "ori_ids": ",".join(extract_ids(oris)) if oris else "", + "promoter_count": len(promoters), + "promoter_ids": ",".join(extract_ids(promoters)) if promoters else "", + "terminator_count": len(terminators), + "terminator_ids": ",".join(extract_ids(terminators)) if terminators else "", + "marker_count": len(markers), + "marker_ids": ",".join(extract_ids(markers)) if markers else "", + "cds_count": len(cdss), + "cds_ids": ",".join(extract_ids(cdss)) if cdss else "", + } + + +class Evaluator(EvalRunner): + """ + Evaluator that generates rollouts from a checkpoint and analyzes them. + + Loads prompts from CSV, generates samples using vLLM, analyzes each sequence + with plasmidkit, and returns a DataFrame with detailed annotation information. + """ + + def __init__(self, config: EvalConfig): + """ + Initialize the evaluator. 
+ + Args: + config: Evaluation configuration + """ + self.config = config + self.llm: Optional[LLM] = None + self.base_config = Config() # For default prompts + self.analyzer = SequenceAnalyzer(config) + + def run_with_trainer(self, trainer: Any, wandb_run: Optional[Any] = None) -> pd.DataFrame: + """ + Run evaluation using the trainer's model directly (already loaded on GPU). + + Args: + trainer: Trainer instance with vLLM model already loaded + wandb_run: Optional wandb run object for logging + + Returns: + DataFrame with detailed annotation information for each sequence + """ + # Load prompts + prompts = self._load_prompts() + + if not prompts: + print("[Evaluator] Warning: No prompts loaded, returning empty DataFrame") + return pd.DataFrame() + + # Use trainer's vLLM instance directly + llm = trainer.llm if hasattr(trainer, 'llm') else None + if llm is None: + print("[Evaluator] Warning: Trainer does not have 'llm' attribute, cannot use in-memory model") + return pd.DataFrame() + + print(f"[Evaluator] Using trainer's vLLM instance directly (no model reload needed)") + self.llm = llm + + # Get sampling parameters + sampling_params = self.config.sampling_params + if sampling_params is None: + sampling_params = SamplingParams( + max_tokens=512, + temperature=0.8, + top_p=0.95, + top_k=0, + ) + + # Expand prompts for multiple samples per prompt + expanded_prompts = [] + prompt_indices = [] + for i, prompt in enumerate(prompts): + for _ in range(self.config.num_samples_per_prompt): + expanded_prompts.append(prompt) + prompt_indices.append(i) + + print(f"[Evaluator] Generating {len(expanded_prompts)} samples from {len(prompts)} prompts") + + # Generate rollouts + outputs = self.llm.generate(expanded_prompts, sampling_params) + + # Process results and analyze each sequence + records = [] + for idx, output in enumerate(outputs): + prompt = output.prompt + completion = output.outputs[0].text.replace(" ", "") + full = prompt + completion + + # Clean sequence for analysis 
(remove non-DNA characters) + cleaned_full = re.sub(r'[^ATCG]', '', full.upper()) + + # Analyze sequence + try: + analysis = self.analyzer.analyze(cleaned_full) + except Exception as e: + print(f"[Evaluator] Warning: Failed to analyze sequence {idx}: {e}") + analysis = { + "ori_count": 0, + "ori_ids": "", + "promoter_count": 0, + "promoter_ids": "", + "terminator_count": 0, + "terminator_ids": "", + "marker_count": 0, + "marker_ids": "", + "cds_count": 0, + "cds_ids": "", + } + + records.append({ + "prompt": prompt, + "response": completion, + "full": full, + "length": len(cleaned_full), + "completion_length": len(completion), + "full_length": len(full), + **analysis, + }) + + # Convert to DataFrame + df = pd.DataFrame(records) + + print(f"[Evaluator] Generated and analyzed {len(df)} rollouts") + + return df + + def _load_prompts(self) -> List[str]: + """ + Load prompts from CSV/parquet file. + + Returns: + List of prompt strings + """ + if not self.config.prompts_path: + print("[Evaluator] Warning: No prompts_path specified, using default prompts") + return [self._get_default_prompts()] + + prompts_path = self.config.prompts_path + + # Check if file exists + if not os.path.exists(prompts_path): + print(f"[Evaluator] Warning: Prompts file not found: {prompts_path}") + return [self._get_default_prompts()] + + try: + # Load based on file extension + if prompts_path.endswith('.parquet'): + df = pd.read_parquet(prompts_path) + elif prompts_path.endswith('.csv'): + df = pd.read_csv(prompts_path) + else: + print(f"[Evaluator] Warning: Unsupported file format: {prompts_path}") + return [self._get_default_prompts()] + + # Extract prompts column + if self.config.prompts_column not in df.columns: + print(f"[Evaluator] Warning: Column '{self.config.prompts_column}' not found in prompts file") + print(f"[Evaluator] Available columns: {list(df.columns)}") + return [self._get_default_prompts()] + + prompts = df[self.config.prompts_column].dropna().tolist() + prompts = 
[str(p).strip() for p in prompts if str(p).strip()] + + print(f"[Evaluator] Loaded {len(prompts)} prompts from {prompts_path}") + return prompts + + except Exception as e: + print(f"[Evaluator] Error loading prompts: {e}") + return [self._get_default_prompts()] + + def _get_default_prompts(self) -> str: + """Get default prompt if prompts file is not available.""" + # Use the default query from config (GFP cassette) + return self.base_config.default_query + + def _initialize_model(self, checkpoint_path: str) -> None: + """ + Initialize vLLM model from checkpoint path. + + Args: + checkpoint_path: Path to model checkpoint + """ + # Check if model is already initialized for this checkpoint + if self.llm is not None: + # For now, always reinitialize - could optimize later + pass + + print(f"[Evaluator] Loading model from checkpoint: {checkpoint_path}") + + try: + # Initialize vLLM with checkpoint path + # vLLM can load from local checkpoint directories + model_kwargs = { + "trust_remote_code": True, + } + + self.llm = LLM(model=checkpoint_path, **model_kwargs) + print(f"[Evaluator] Model loaded successfully") + + except Exception as e: + print(f"[Evaluator] Error loading model: {e}") + import traceback + traceback.print_exc() + raise \ No newline at end of file diff --git a/src/eval/eval_config.py b/src/eval/eval_config.py new file mode 100644 index 0000000..27d361b --- /dev/null +++ b/src/eval/eval_config.py @@ -0,0 +1,33 @@ +from pydantic import BaseModel, ConfigDict +from vllm import SamplingParams +from typing import Optional + +class EvalConfig(BaseModel): + """ + Configuration for evaluation analysis. + + Similar to RewardConfig but focused on detailed annotation extraction + rather than scoring. 
+ """ + model_config = ConfigDict(arbitrary_types_allowed=True) + + # Model configuration + model_name: str + model_path: str + + # Prompts configuration + prompts_path: Optional[str] = None # Path to CSV/parquet file with prompts + prompts_column: str = "prompt" # Column name containing prompts + num_samples_per_prompt: int = 10 # Number of samples to generate per prompt + + # Annotation configuration (similar to RewardConfig) + overlap_merge_threshold: float = 0.8 # Overlap merge threshold for annotations + + # Generation configuration + sampling_params: Optional[SamplingParams] = None + + # Logging configuration + write_to_wandb: bool = False + wandb_project: Optional[str] = None + wandb_run_name: Optional[str] = None + diff --git a/src/rewards/bioinformatics/reward_config.py b/src/rewards/bioinformatics/reward_config.py index 5de9774..86ec753 100644 --- a/src/rewards/bioinformatics/reward_config.py +++ b/src/rewards/bioinformatics/reward_config.py @@ -12,7 +12,7 @@ class RewardConfig(BaseModel): length_reward_bonus: float = 0.5 # bonus multiplier for sequences in ideal length range location_aware: bool = True # reward sequences that are located in the correct location (e.g. 
promoter then cds then terminator) # Penalty factor applied when min/max constraints are violated (outside of range) - violation_penalty_factor: float = 0.5 + violation_penalty_factor: float = 1.0 # Deprecated - use length_reward_mode instead length_penalty: bool = False @@ -20,7 +20,7 @@ class RewardConfig(BaseModel): ori_min: int = 1 ori_max: int = 1 allowed_oris: Optional[List[str]] = None - ori_weight: float = 1.0 + ori_weight: float = 1.5 promoter_min: int = 1 promoter_max: int = 1 diff --git a/src/runners/grpo.py b/src/runners/grpo.py index 9184df0..fea171f 100644 --- a/src/runners/grpo.py +++ b/src/runners/grpo.py @@ -6,16 +6,21 @@ from src.rewards.bioinformatics.scorer import Scorer from src.rewards.bioinformatics.reward_config import RewardConfig from src.rewards.bioinformatics.logger import RewardComponentLogger +from src.eval.eval import Evaluator +from src.eval.eval_config import EvalConfig +from src.utils.training_utils import EvalCallback, test_checkpoint_directory_write +from vllm import SamplingParams import datetime from typing import List import wandb from concurrent.futures import ThreadPoolExecutor from threading import Lock import re +import os # Configuration cfg = Config() -run_name = f"grpo-{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}" +run_name = f"grpo-production-{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}" # Dataset loading def load_train_val_datasets(): @@ -52,17 +57,22 @@ def select_prompt_column(ds): "pad_token_id": tok.pad_token_id, } -# Training configuration +# Training configuration - use /s3 mount point with prefix path +checkpoint_dir = f"/s3/{cfg.checkpoints_path.rstrip('/')}/grpo-production/{run_name}" + +# Test checkpoint directory write access before proceeding +test_checkpoint_directory_write(checkpoint_dir) + args = GRPOConfig( model_init_kwargs=model_init_kwargs, - output_dir=f"/s3/checkpoints/verl-grpo/{run_name}", + output_dir=checkpoint_dir, # Training parameters num_train_epochs=20, - learning_rate=3e-6, 
+ learning_rate=cfg.grpo_learning_rate, lr_scheduler_type="constant", warmup_ratio=0.0, - per_device_train_batch_size=16, + per_device_train_batch_size=cfg.grpo_per_device_train_batch_size, gradient_accumulation_steps=1, max_steps=-1, max_grad_norm=0.5, @@ -71,6 +81,7 @@ def select_prompt_column(ds): # Logging and checkpointing save_strategy="steps", save_steps=100, + save_total_limit=5, # Keep last 5 checkpoints logging_strategy="steps", logging_steps=1, report_to=["wandb"], @@ -78,15 +89,15 @@ def select_prompt_column(ds): # Evaluation do_eval=True, eval_strategy="steps", - eval_steps=100, + eval_steps=50, # Evaluate every 50 steps # Optimization bf16=torch.cuda.is_available(), gradient_checkpointing=False, # GRPO-specific - beta=1e-3, - epsilon=0.2, + beta=cfg.grpo_beta, + epsilon=cfg.grpo_epsilon, loss_type="bnpo", scale_rewards=True, mask_truncated_completions=False, @@ -95,10 +106,10 @@ def select_prompt_column(ds): # Generation parameters remove_unused_columns=False, max_prompt_length=1024, - num_generations=8, + num_generations=cfg.grpo_num_generations, max_completion_length=256, - temperature=0.95, - top_p=0.90, + temperature=cfg.grpo_temperature, + top_p=cfg.grpo_top_p, # vLLM configuration use_vllm=True, @@ -106,20 +117,31 @@ def select_prompt_column(ds): vllm_mode="colocate", ) -# Reward configuration +# Reward configuration - production parameters from sweep reward_config = RewardConfig( - punish_mode=False, - length_penalty=True, - min_length=1000, + punish_mode=True, # Use punish mode for better constraint learning + length_reward_mode=True, + min_length=2000, max_length=30000, + ideal_min_length=7000, + ideal_max_length=20000, + length_reward_bonus=0.7085046275614012, ori_min=1, ori_max=1, + ori_weight=1.0, promoter_min=1, promoter_max=5, + promoter_weight=1.0, terminator_min=0, terminator_max=2, + terminator_weight=0.5, marker_min=1, marker_max=2, + marker_weight=1.0, + cds_min=1, + cds_max=5, + cds_weight=1.0, + location_aware=True, ) # Initialize 
scorer and logger @@ -127,6 +149,27 @@ def select_prompt_column(ds): reward_logger = RewardComponentLogger(log_frequency=1) component_lock = Lock() +# Initialize evaluation callback +eval_config = EvalConfig( + model_name=cfg.model, + model_path=cfg.model, # Not read by run_with_trainer(), which uses the trainer's in-memory vLLM instance + prompts_path=cfg.val_dataset, # Use test.parquet for evaluation prompts + prompts_column="prompt", + num_samples_per_prompt=5, # Fewer samples for quick testing + overlap_merge_threshold=0.8, + sampling_params=SamplingParams( + max_tokens=256, + temperature=0.95, + top_p=0.90, + top_k=0, + ), + write_to_wandb=True, + wandb_project=cfg.wandb_project, + wandb_run_name=run_name, +) +evaluator = Evaluator(eval_config) +eval_callback = EvalCallback(evaluator) + # Reward function def score_single(idx_and_seq): """Score a single sequence and log components thread-safely.""" @@ -161,28 +204,38 @@ def batch_reward_fn(prompts: List[str], completions: List[str], **kwargs) -> Lis return [r[0] for r in results] # Initialize W&B -wandb.init( +wandb_run = wandb.init( project=cfg.wandb_project, entity=cfg.wandb_entity, name=run_name, + tags=["production", "grpo", "optimized-hyperparams"], config={ "model": cfg.model, "reward_config": reward_config.model_dump(), "training": { - "learning_rate": args.learning_rate, - "batch_size": args.per_device_train_batch_size, + "learning_rate": cfg.grpo_learning_rate, + "batch_size": cfg.grpo_per_device_train_batch_size, "num_epochs": args.num_train_epochs, + "num_generations": cfg.grpo_num_generations, }, "grpo": { - "beta": args.beta, - "epsilon": args.epsilon, - "temperature": args.temperature, - "num_generations": args.num_generations, + "beta": cfg.grpo_beta, + "epsilon": cfg.grpo_epsilon, + "temperature": cfg.grpo_temperature, + "top_p": cfg.grpo_top_p, "loss_type": args.loss_type, }, + "checkpoint_dir": checkpoint_dir, }, ) +# Print wandb URL and checkpoint info +if wandb_run: + print(f"\n{'='*80}") + print(f"🚀 W&B Run URL: 
{wandb_run.url}") + print(f"📁 Checkpoint Directory: {checkpoint_dir}") + print(f"{'='*80}\n") + # Initialize trainer trainer = GRPOTrainer( model=cfg.model, @@ -191,10 +244,32 @@ def batch_reward_fn(prompts: List[str], completions: List[str], **kwargs) -> Lis train_dataset=train_ds, eval_dataset=eval_ds, processing_class=tok, - callbacks=[reward_logger], + callbacks=[reward_logger, eval_callback], ) +# Set trainer reference in callback (for accessing trainer.llm) +eval_callback.set_trainer(trainer) + # Train and save +print(f"Starting training with {args.num_train_epochs} epochs...") trainer.train() -trainer.save_model(args.output_dir) -tok.save_pretrained(args.output_dir) + +# Save final model and tokenizer +print(f"Saving final model to {checkpoint_dir}...") +trainer.save_model(checkpoint_dir) +tok.save_pretrained(checkpoint_dir) + +# Log final checkpoint as W&B artifact +artifact = wandb.Artifact( + name=f"model-{run_name}", + type="model", + description=f"Final GRPO model checkpoint from production run with optimized hyperparameters", +) +artifact.add_dir(checkpoint_dir) +wandb_run.log_artifact(artifact) + +print(f"✓ Training complete! 
Model saved to {checkpoint_dir}") +print(f"✓ Model artifact logged to W&B: {wandb_run.url}") + +# Finish run +wandb.finish() diff --git a/src/runners/grpo_sweep.py b/src/runners/grpo_sweep.py index 44d2ec3..6f95db3 100644 --- a/src/runners/grpo_sweep.py +++ b/src/runners/grpo_sweep.py @@ -11,12 +11,17 @@ from src.rewards.bioinformatics.scorer import Scorer from src.rewards.bioinformatics.reward_config import RewardConfig from src.rewards.bioinformatics.logger import RewardComponentLogger +from src.eval.eval import Evaluator +from src.eval.eval_config import EvalConfig +from src.utils.training_utils import EvalCallback, test_checkpoint_directory_write +from vllm import SamplingParams import datetime from typing import List import wandb from concurrent.futures import ThreadPoolExecutor from threading import Lock import re +import os def load_train_val_datasets(cfg): @@ -39,7 +44,11 @@ def main(): # Initialize W&B run with a meaningful name run_name = f"grpo-sweep-{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}" - run = wandb.init(name=run_name) + run = wandb.init( + name=run_name, + entity=cfg.wandb_entity, + project=cfg.wandb_project, + ) # Get sweep params sweep_config = wandb.config @@ -65,8 +74,12 @@ def main(): "pad_token_id": tok.pad_token_id, } - # Training configuration with sweep parameters - checkpoint_dir = f"/mnt/s3/phd-research-storage-1758274488/checkpoints/grpo-sweeps/{run_name}" + # Training configuration - use /s3 mount point with prefix path + checkpoint_dir = f"/s3/{cfg.checkpoints_path.rstrip('/')}/grpo-sweeps/{run_name}" + + # Test checkpoint directory write access before proceeding + test_checkpoint_directory_write(checkpoint_dir) + args = GRPOConfig( model_init_kwargs=model_init_kwargs, output_dir=checkpoint_dir, @@ -93,7 +106,7 @@ def main(): # Evaluation do_eval=True, eval_strategy="steps", - eval_steps=100, + eval_steps=50, # More frequent eval for sweeps to track progress # Optimization bf16=torch.cuda.is_available(), @@ -148,11 +161,35 
@@ def main(): location_aware=sweep_config.get("reward_location_aware", True), ) + # Log reward_config to wandb + wandb.config.update({"reward_config": reward_config.model_dump()}) + # Initialize scorer and logger scorer = Scorer(reward_config) reward_logger = RewardComponentLogger(log_frequency=5) # Log frequently for short runs component_lock = Lock() + # Initialize evaluation callback + eval_config = EvalConfig( + model_name=cfg.model, + model_path=cfg.model, + prompts_path=cfg.val_dataset, # Use test.parquet for evaluation prompts + prompts_column="prompt", + num_samples_per_prompt=5, # Fewer samples for sweeps + overlap_merge_threshold=0.8, + sampling_params=SamplingParams( + max_tokens=256, + temperature=0.95, + top_p=0.90, + top_k=0, + ), + write_to_wandb=True, + wandb_project=cfg.wandb_project, + wandb_run_name=run_name, + ) + evaluator = Evaluator(eval_config) + eval_callback = EvalCallback(evaluator) + # Reward function def score_single(idx_and_seq): idx, seq = idx_and_seq @@ -179,9 +216,12 @@ def batch_reward_fn(prompts: List[str], completions: List[str], **kwargs) -> Lis train_dataset=train_ds, eval_dataset=eval_ds, processing_class=tok, - callbacks=[reward_logger], + callbacks=[reward_logger, eval_callback], ) + # Set trainer reference in callback (for accessing trainer.llm) + eval_callback.set_trainer(trainer) + # Train trainer.train() diff --git a/src/utils/training_utils.py b/src/utils/training_utils.py new file mode 100644 index 0000000..91af73f --- /dev/null +++ b/src/utils/training_utils.py @@ -0,0 +1,243 @@ +from transformers import TrainerCallback +from typing import Any, Protocol, Optional +import wandb +import pandas as pd +import os +import sys +import datetime +from abc import ABC, abstractmethod + + +def test_checkpoint_directory_write(checkpoint_dir: str) -> None: + """ + Test that checkpoint directory exists and is writable. + Raises an error and exits if write test fails. 
+ + Args: + checkpoint_dir: Path to checkpoint directory to test + """ + try: + # Create directory if it doesn't exist + os.makedirs(checkpoint_dir, exist_ok=True) + + # Test write access + test_file = os.path.join(checkpoint_dir, ".write_test") + test_content = f"Write test at {datetime.datetime.now().isoformat()}\n" + + # Write test file + with open(test_file, 'w') as f: + f.write(test_content) + + # Read back to verify + with open(test_file, 'r') as f: + read_content = f.read() + + if read_content != test_content: + raise IOError(f"Write test failed: content mismatch") + + # Clean up test file + os.remove(test_file) + + print(f"✓ Checkpoint directory write test passed: {checkpoint_dir}") + + except Exception as e: + error_msg = ( + f"\n{'='*80}\n" + f"❌ CHECKPOINT DIRECTORY WRITE TEST FAILED\n" + f"{'='*80}\n" + f"Directory: {checkpoint_dir}\n" + f"Error: {str(e)}\n" + f"\nThis usually means:\n" + f" - S3 mount is not available at /s3\n" + f" - Insufficient permissions to write to S3\n" + f" - Disk space is full\n" + f"\nPlease check:\n" + f" 1. Docker volume mount: /mnt/s3/phd-research-storage-1758274488:/s3:rw\n" + f" 2. S3 mount is accessible: ls -la /s3\n" + f" 3. Write permissions on S3 mount\n" + f"{'='*80}\n" + ) + print(error_msg, file=sys.stderr) + sys.exit(1) + + +class EvalRunner(ABC): + """ + Protocol for evaluation runner classes. + + Classes implementing this protocol should have a `run_with_trainer()` method that: + - Takes trainer and wandb_run as parameters + - Uses the trainer's model directly (already loaded on GPU) + - Returns a pandas DataFrame containing evaluation results + - Each row should contain evaluation metrics (e.g., score, length, etc.) + """ + + @abstractmethod + def run_with_trainer(self, trainer: Any, wandb_run: Optional[Any] = None) -> pd.DataFrame: + """ + Run evaluation using the trainer's model directly. 
+ + Args: + trainer: Trainer instance with model already loaded + wandb_run: Optional wandb run object for logging + + Returns: + DataFrame with evaluation results, each row containing evaluation metrics + """ + ... + + +class EvalCallback(TrainerCallback): + """ + Minimal training callback that runs evaluation during training and logs results to wandb. + + Uses the trainer's in-memory model directly (no need to reload from disk). + Delegates all evaluation logic to the provided evaluator class. + """ + + def __init__(self, evaluator: EvalRunner): + """ + Initialize the evaluation callback. + + Args: + evaluator: An object with a `run_with_trainer(trainer, wandb_run)` method that + performs evaluation using the trainer's model and returns a DataFrame + """ + self.evaluator = evaluator + self.last_eval_step = -1 + self._trainer_ref = None # Will be set when trainer is available + + def set_trainer(self, trainer: Any): + """Set trainer reference for callbacks that need it.""" + self._trainer_ref = trainer + + def on_evaluate(self, args, state, control, **kwargs): + """Run evaluation when trainer evaluation is triggered.""" + print(f"[EvalCallback] on_evaluate called at step {state.global_step}") + + # Avoid duplicate evals at the same step + if state.global_step == self.last_eval_step: + print(f"[EvalCallback] Already evaluated at step {state.global_step}, skipping") + return + + self.last_eval_step = state.global_step + + try: + # Get trainer from kwargs - transformers TrainerCallback passes 'model' and sometimes 'trainer' + # For GRPOTrainer, we need to access it via the callback's parent reference or kwargs + trainer = kwargs.get('trainer') + if trainer is None: + # Try to get model directly - GRPOTrainer might pass model in kwargs + model = kwargs.get('model') + if model is not None: + # If we have model but not trainer, we need to find trainer another way + # Actually, let's check if we can store a reference to trainer in __init__ + print("[EvalCallback] 
Warning: Trainer not found in kwargs, checking callback context") + # Fallback: use self if callback was passed trainer reference + if hasattr(self, '_trainer_ref'): + trainer = self._trainer_ref + else: + print("[EvalCallback] Warning: Cannot access trainer, skipping evaluation") + return + + print(f"[EvalCallback] Running evaluation at step {state.global_step} (using model from trainer)") + + # Get wandb run object and URL + wandb_run = wandb.run + if wandb_run: + wandb_url = wandb_run.url + print(f"[EvalCallback] W&B Run URL: {wandb_url}") + + # Run evaluation using the trainer's model directly + results_df = self.evaluator.run_with_trainer(trainer, wandb_run) + + if results_df is None or len(results_df) == 0: + print("[EvalCallback] Warning: Evaluation returned no results") + return + + # Log results to wandb + self._log_results(results_df, state.global_step) + + print(f"[EvalCallback] Logged evaluation results for step {state.global_step}") + + except Exception as e: + print(f"[EvalCallback] Error during evaluation: {e}") + import traceback + traceback.print_exc() + + def _get_checkpoint_path(self, args, state) -> Optional[str]: + """ + Get the path to the current checkpoint. 
+ + Args: + args: Training arguments + state: Trainer state + + Returns: + Path to a checkpoint directory (this step's checkpoint, else output_dir), or None if no valid checkpoint exists yet + """ + # Check if there's a checkpoint directory for this step + checkpoint_dir = f"{args.output_dir}/checkpoint-{state.global_step}" + if os.path.exists(checkpoint_dir) and os.path.exists(f"{checkpoint_dir}/config.json"): + return checkpoint_dir + + # Fallback: use output_dir if checkpoint doesn't exist yet + # Check if output_dir has config.json (meaning it's a valid checkpoint) + if os.path.exists(args.output_dir) and os.path.exists(f"{args.output_dir}/config.json"): + return args.output_dir + + # If no checkpoint exists yet, return None - evaluator will use base model + return None + + def _log_results(self, results_df: pd.DataFrame, step: int) -> None: + """ + Log evaluation results to wandb as both table and artifact. + + Args: + results_df: DataFrame containing evaluation results + step: Current training step + """ + if results_df is None or len(results_df) == 0: + return + + df = results_df + + # Log as wandb table (for quick dashboard viewing) + wandb.log({ + "eval/step": step, + "eval/results_table": wandb.Table(dataframe=df), + }) + + # Log summary statistics for numeric columns + numeric_cols = df.select_dtypes(include=['number']).columns + stats = {} + for col in numeric_cols: + if col != "step": + stats[f"eval/stats/{col}/mean"] = float(df[col].mean()) + stats[f"eval/stats/{col}/std"] = float(df[col].std()) + stats[f"eval/stats/{col}/min"] = float(df[col].min()) + stats[f"eval/stats/{col}/max"] = float(df[col].max()) + stats[f"eval/stats/{col}/median"] = float(df[col].median()) + + if stats: + wandb.log(stats) + + # Create and log artifact (for full data export) + artifact_name = f"eval_results_step_{step}" + artifact = wandb.Artifact( + name=artifact_name, + type="evaluation_results", + description=f"Evaluation results at training step {step}", + metadata={ + "step": step, + "total_samples": 
len(df), + } + ) + + # Add CSV to artifact + with artifact.new_file("results.csv", mode="w") as f: + df.to_csv(f, index=False) + + # Log artifact + wandb.log_artifact(artifact) + diff --git a/sweeps/configs/sweep_config_training_with_eval.yaml b/sweeps/configs/sweep_config_training_with_eval.yaml new file mode 100644 index 0000000..61f1bfe --- /dev/null +++ b/sweeps/configs/sweep_config_training_with_eval.yaml @@ -0,0 +1,144 @@ +# W&B Sweep Configuration for PlasmidRL GRPO - Training + Length Reward + Eval +# Focus: Broad training hyperparameter search with length-based rewards and evaluation +# Duration: 500 steps per trial for stable evaluation +# Includes: Evaluation callback to track detailed sequence analysis +# +# Strategy: Test 2 length configurations across full hyperparameter space +# +# Usage: +# 1. Initialize the sweep: +# wandb sweep sweeps/configs/sweep_config_training_with_eval.yaml +# +# 2. Copy the sweep ID from output (e.g., mcclain/plasmidrl-grpo-sweeps/abc123xyz) +# +# 3. 
Run agent(s): +# SWEEP_ID=mcclain/plasmidrl-grpo-sweeps/abc123xyz docker compose up grpo-sweep +# +# Monitor progress at: https://wandb.ai/mcclain/plasmidrl-grpo-sweeps + +program: sweeps/run_sweep_agent.py +method: bayes +metric: + name: reward_components/total_reward/mean + goal: maximize + +parameters: + # ==================== TRAINING HYPERPARAMETERS (BROAD SEARCH) ==================== + + max_steps: + value: 500 # Longer runs for stable evaluation + + # Learning rate - broad exploration + learning_rate: + distribution: log_uniform_values + min: 1e-6 + max: 1e-4 + + # Batch size + per_device_train_batch_size: + values: [8, 16, 32] + + # Generations per batch + num_generations: + values: [4, 8, 16] + + # Temperature - sampling randomness + temperature: + distribution: uniform + min: 0.7 + max: 1.4 + + # Top-p - nucleus sampling + top_p: + distribution: uniform + min: 0.85 + max: 0.95 + + # Beta - KL penalty coefficient + beta: + distribution: log_uniform_values + min: 1e-4 + max: 1e-2 + + # Epsilon - PPO-style clipping + epsilon: + distribution: uniform + min: 0.1 + max: 0.3 + + # ==================== LENGTH REWARD PARAMETERS (2 COMBOS) ==================== + + reward_length_reward_mode: + value: true # Enable length-based rewards + + # Intended: 2 length configurations (NOTE(review): each `values` list below is a separate sweep parameter sampled independently, so mixed combos such as min=5000 with ideal_min=3000 can also be drawn - confirm) + # Combo 1: min=2000, ideal_min=3000, ideal_max=12000, max=15000 + # Combo 2: min=5000, ideal_min=7000, ideal_max=20000, max=30000 + + reward_min_length: + values: [2000, 5000] + + reward_max_length: + values: [15000, 30000] + + reward_ideal_min_length: + values: [3000, 7000] + + reward_ideal_max_length: + values: [12000, 20000] + + # Bonus multiplier for being in ideal range + reward_length_reward_bonus: + distribution: uniform + min: 0.3 + max: 0.8 + + # ==================== REWARD CONFIG (FIXED - STANDARD) ==================== + + reward_punish_mode: + value: true + + reward_promoter_max: + value: 5 + + reward_terminator_max: + value: 2 + + reward_marker_max: + value: 2 + + reward_cds_max: + 
value: 5 + + reward_location_aware: + value: true + + # Reward component weights (all enabled with standard weights) + reward_ori_weight: + value: 1.0 + + reward_promoter_weight: + value: 1.0 + + reward_terminator_weight: + value: 0.5 + + reward_marker_weight: + value: 1.0 + + reward_cds_weight: + value: 1.0 + +# Early termination - adjusted for longer runs +early_terminate: + type: hyperband + min_iter: 100 # Higher threshold for 500-step runs + eta: 3 + s: 2 + +# Optional: Limit total number of runs +# run_cap: 50 + + + +