This project integrates Reinforcement Learning (RL), specifically Soft Actor-Critic (SAC), with Large Language Models (LLMs) like OpenAI's GPT series. The core idea is to leverage LLMs to automatically generate and iteratively refine reward functions for robotic tasks, aiming to improve learning efficiency and final policy performance in simulated environments like FetchReach.
- RL Algorithm: Uses the Soft Actor-Critic (SAC) implementation from Stable Baselines3.
- LLM Integration: Connects with OpenAI's GPT and Groq API models to generate and refine reward functions based on task descriptions and potentially environment feedback.
- Environment Support: Currently configured for the FetchReach-v3 environment from gymnasium-robotics.
- Configuration: Uses Hydra for flexible and structured configuration management.
- Logging: Includes TensorBoard logging for visualizing training progress (e.g., rewards, losses).
- Iterative Refinement: Supports multiple iterations where the reward function is updated by the LLM based on previous results.
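To make the idea of a generated reward function concrete, here is a minimal sketch of the kind of dense reward an LLM might propose for a reach task. It is an illustration, not output from this project: the function name `compute_reward`, its signature, the penalty weight, and the 0.05 success threshold are all assumptions.

```python
import math


def compute_reward(achieved_goal, desired_goal, action, success_threshold=0.05):
    """Hypothetical dense reward for a reach task (illustrative only)."""
    # Euclidean distance between the gripper position and the target position.
    distance = math.dist(achieved_goal, desired_goal)
    # Small penalty on action magnitude to discourage large, jerky motions.
    action_penalty = 0.01 * sum(a * a for a in action)
    # Bonus once the gripper is within the (assumed) success threshold.
    success_bonus = 1.0 if distance < success_threshold else 0.0
    return -distance - action_penalty + success_bonus
```

A reward of this shape gives the agent a learning signal at every step, whereas a sparse success-only reward gives a signal only near the goal. That difference is the kind of trade-off the refinement loop is meant to explore.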
- Clone the Repository:
    git clone https://github.com/PCDorg/RewGen4Robotics.git
    cd RewGen4Robotics
- Create and Activate Virtual Environment (Recommended):
    python -m venv venv
    source venv/bin/activate   # Linux/macOS
    # venv\Scripts\activate    # Windows
- Install Dependencies:
    pip install -r requirements.txt
  (Note: If requirements.txt doesn't exist, you'll need to create one based on the project's imports.)
- Set API Keys:
    export OPENAI_API_KEY='your-actual-openai-api-key'
    export GROQ_API_KEY='your-actual-groq-api-key'
  (Keep the keys in environment variables or a .env file rather than hard-coding them.)
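A missing key usually surfaces only when the first LLM request fails, so it can help to check for it up front. The following is a minimal sketch, not code from the repository: the helper name `require_api_key` and the provider-to-variable mapping are assumptions based on the variable names and provider values shown in this README.

```python
import os

# Assumed mapping from the llm.provider value to the environment variable it needs.
PROVIDER_KEYS = {"openai": "OPENAI_API_KEY", "groq": "GROQ_API_KEY"}


def require_api_key(provider, env=None):
    """Return the API key for the given provider, or raise if it is unset or empty."""
    env = os.environ if env is None else env
    name = PROVIDER_KEYS[provider]
    value = env.get(name)
    if not value:
        raise RuntimeError(f"{name} is not set; export it before running gen.py")
    return value
```

Calling `require_api_key("groq")` at startup would fail fast with a readable message instead of an authentication error partway through a run.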
Configuration is managed via Hydra. Key configuration files are typically located in a cfg/ directory:
- cfg/config.yaml: Main configuration file linking others.
- cfg/training/: Training parameters (algorithm hyperparameters, total timesteps).
- cfg/env/: Environment settings (environment ID).
- cfg/llm/: LLM parameters (model names for OpenAI/Groq, prompts).
- cfg/experiment/: Experiment settings (number of iterations, logging paths).
- cfg/paths/: Defines paths for outputs, models, logs.
You can override parameters via the command line:
    python gen.py training.total_timesteps=50000 llm.model="llama3-70b-8192" llm.provider="groq"

Iterative Training with LLM (gen.py)
The main script (gen.py) orchestrates the iterative process:
- Run with Default Configuration:
    # Ensure correct environment is set via defaults or command line
    python gen.py env=go1   # Or env=fetchReach, etc.
- Run with Overrides:
    python gen.py experiment.iterations=5 training.learning_rate=1e-4 llm.provider=openai llm.model=gpt-4
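The overrides in these commands address nested configuration values with dotted keys. As a rough illustration of that mapping, the sketch below applies `key=value` strings to a nested dictionary. It is a standard-library stand-in, not Hydra's parser: the function name `apply_overrides` is hypothetical, and values stay as strings here, whereas Hydra also converts types.

```python
def apply_overrides(config, overrides):
    """Apply 'a.b=value' strings to a nested dict and return that dict."""
    for item in overrides:
        dotted_key, _, raw_value = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = config
        # Walk (and create, if needed) the nested groups named before the last dot.
        for name in parents:
            node = node.setdefault(name, {})
        # Values are kept as plain strings in this sketch.
        node[leaf] = raw_value
    return config


config = apply_overrides(
    {},
    ["experiment.iterations=5", "llm.provider=openai", "llm.model=gpt-4"],
)
print(config)
```

Each prefix before the last dot corresponds to a configuration group such as `llm` or `experiment`, which is why an override only needs to name the value being changed.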
The script typically performs the following loop for the specified number of iterations:
- Generates/Refines a reward function using the LLM (OpenAI or Groq).
- Updates the environment wrapper or reward logic.
- Trains the SAC agent using the current reward function.
- Saves the trained model, logs, and generated reward function.
- Potentially uses results from the training run to inform the next LLM generation step.
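The loop above can be sketched as follows. This is a schematic under stated assumptions, not the contents of gen.py: `run_iterations` and both `fake_*` helpers are hypothetical stand-ins for the LLM call and for SAC training, and the metrics they return are placeholders, not results.

```python
from typing import Callable, Dict, List, Optional


def run_iterations(
    num_iterations: int,
    generate_reward: Callable[[Optional[Dict]], str],
    train_agent: Callable[[str], Dict],
) -> List[Dict]:
    """Alternate between reward generation and training, feeding results back."""
    history: List[Dict] = []
    feedback: Optional[Dict] = None
    for iteration in range(1, num_iterations + 1):
        # Generate (first pass) or refine (later passes) the reward function.
        reward_code = generate_reward(feedback)
        # Train an agent with the current reward function.
        metrics = train_agent(reward_code)
        # Record what was produced in this iteration.
        history.append(
            {"iteration": iteration, "reward_code": reward_code, "metrics": metrics}
        )
        # Results from this run inform the next generation step.
        feedback = metrics
    return history


def fake_generate_reward(feedback):
    """Stand-in for the LLM call: returns a versioned placeholder string."""
    version = 0 if feedback is None else feedback["version"] + 1
    return f"# reward function v{version}"


def fake_train_agent(reward_code):
    """Stand-in for SAC training: echoes the version as placeholder metrics."""
    version = int(reward_code.rsplit("v", 1)[1])
    return {"version": version}


history = run_iterations(3, fake_generate_reward, fake_train_agent)
for record in history:
    print(record["iteration"], record["reward_code"])
```

Passing the generator and trainer in as callables keeps the loop independent of any particular LLM provider or RL library, which makes the control flow easy to test in isolation.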
.
├── cfg/                 # Hydra configuration files
│   ├── config.yaml      # Main config
│   ├── training/
│   ├── env/
│   ├── llm/
│   ├── experiment/
│   └── paths/
├── envs/                # Custom environment wrappers (if any)
├── models/              # Saved RL models per iteration
├── results/             # Logs, generated rewards, plots
│   └── reach/           # Example output for 'reach' task
│       ├── iteration_1/
│       ├── iteration_2/
│       └── ...
├── utils/               # Utility functions (e.g., LLM interaction)
├── gen.py               # Main script
├── requirements.txt     # Project dependencies
└── README.md            # This file
During execution, the project generates:
- Saved Models: RL agent models saved periodically or at the end of each iteration (in models/ or results/.../models/).
- TensorBoard Logs: Training metrics logged for TensorBoard (in results/.../tensorboard_logs/).
- Generated Rewards: Text files or logs containing the prompts and the reward functions generated by the LLM for each iteration (in results/.../rewards/).
- Console Logs: Output detailing the progress of each iteration, training steps, and LLM interactions.
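When inspecting outputs after a run, it can be handy to locate the most recent iteration directory. A minimal sketch, assuming the `iteration_N` naming shown in the directory tree; the helper name `latest_iteration_dir` is hypothetical, and the actual layout may differ depending on your path configuration.

```python
import re
from pathlib import Path

# Matches directory names such as "iteration_1" or "iteration_12".
ITERATION_PATTERN = re.compile(r"iteration_(\d+)$")


def latest_iteration_dir(task_dir):
    """Return the highest-numbered iteration_N directory, or None if there is none."""
    candidates = []
    for path in Path(task_dir).iterdir():
        match = ITERATION_PATTERN.match(path.name)
        if path.is_dir() and match:
            candidates.append((int(match.group(1)), path))
    if not candidates:
        return None
    # Compare by the numeric suffix so that iteration_10 sorts after iteration_2.
    return max(candidates, key=lambda pair: pair[0])[1]
```

For example, `latest_iteration_dir("results/reach")` would return the path of the last completed iteration if the run used that output location.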
Launch TensorBoard to view learning curves and other metrics:
    # Adjust the logdir path based on your output structure
    tensorboard --logdir results/reach/tensorboard_logs

Navigate to http://localhost:6006 in your browser.
- This project utilizes models and APIs from OpenAI and Groq.
- Relies on the Stable Baselines3 library for RL implementations.
- Uses Hydra for configuration management.
- Uses environments from Gymnasium and Gymnasium-Robotics.