Multi-architecture Graph Neural Network comparison for drug-disease link prediction using the PrimeKG knowledge graph. Compares five GNN architectures (RGCN, GCN, GAT, GraphSAGE, GIN) and an MLP baseline with strict evaluation under hard negative sampling.
This project implements a complete pipeline for biomedical link prediction, from data preprocessing to advanced analysis. Multiple GNN encoders learn from drug-gene and gene-disease relationships to predict potential therapeutic indications, with strict evaluation ensuring no data leakage and hard negative sampling for reliable metrics.
- Multi-model comparison: 6 architectures (RGCN, GCN, GAT, GraphSAGE, GIN, MLP) with unified training and evaluation
- Strict evaluation: Hard negative sampling (50 negatives/positive) with data-leakage-free splitting
- Modular design: Plug-and-play encoder registry — easy to add new GNN architectures
- Comprehensive evaluation: classification metrics, ranking metrics, and error analysis
- Medical validation: biological plausibility checking and evidence gathering
- Drug repurposing: disease-specific case studies with pathway analysis
- Interpretable predictions: path-based explanations with natural language generation
- Embedding analysis: t-SNE/UMAP visualization and clustering
- GPU accelerated: optimized for fast inference and batch processing (CUDA / MPS / CPU)
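The plug-and-play encoder registry (`get_encoder()` in `src/models/__init__.py`) can be sketched roughly as follows. This is an illustrative version: the decorator name and constructor arguments here are assumptions, not the project's exact API.

```python
# Illustrative encoder registry sketch; the real src/models/__init__.py
# may use different names and signatures.
ENCODER_REGISTRY = {}

def register_encoder(name):
    """Decorator that records an encoder class in the registry under `name`."""
    def wrapper(cls):
        ENCODER_REGISTRY[name] = cls
        return cls
    return wrapper

def get_encoder(name, **kwargs):
    """Look up an encoder by name and instantiate it with `kwargs`."""
    if name not in ENCODER_REGISTRY:
        raise ValueError(f"Unknown encoder '{name}'. "
                         f"Available: {sorted(ENCODER_REGISTRY)}")
    return ENCODER_REGISTRY[name](**kwargs)

@register_encoder("mlp")
class MLPEncoder:
    """Stand-in for a real encoder; only stores its hyperparameters."""
    def __init__(self, hidden_dim=128, embedding_dim=64):
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

encoder = get_encoder("mlp", hidden_dim=256)
```

Adding a new architecture then only requires defining the class and registering it under a new name; the training script selects it via `--model <name>`.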
Drug --[interacts]--> Gene --[associated]--> Disease
The model learns to predict drug-disease indications by:
- Encoding drug-gene and gene-disease relationships
- Learning node embeddings via relational graph convolutions
- Predicting links between drug and disease nodes
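The link-scoring step uses a DistMult decoder: a triple (drug, relation, disease) is scored by the sum of the element-wise product of the three embeddings. This toy NumPy sketch uses random embeddings in place of the GNN encoder's learned ones, and all sizes here are made up for illustration:

```python
import numpy as np

# Toy DistMult scoring sketch; real embeddings come from the trained encoder.
rng = np.random.default_rng(0)
num_drugs, dim = 5, 64

drug_emb = rng.normal(size=(num_drugs, dim))
disease_emb = rng.normal(size=(4, dim))
rel_indication = rng.normal(size=(dim,))  # one learned vector per relation type

def distmult_score(h, r, t):
    """DistMult: score(h, r, t) = sum_i h_i * r_i * t_i."""
    return float(np.sum(h * r * t))

# Score every drug against disease 0 and rank candidates, best first.
scores = [distmult_score(drug_emb[d], rel_indication, disease_emb[0])
          for d in range(num_drugs)]
ranked = np.argsort(scores)[::-1]
```

Higher scores indicate a more plausible drug-disease indication; ranking candidate drugs by this score is what the case-study and evaluation scripts build on.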
This project uses PrimeKG, a precision medicine knowledge graph containing:
- 4.5 million relationships
- 20 different data sources
- 129,375 nodes (drugs, diseases, genes, proteins, etc.)
- Multiple relation types including:
- Drug-Gene interactions
- Gene-Gene interactions
- Gene-Disease associations
Our processed graph:
- 30,926 nodes (6,282 drugs, 5,593 diseases, 19,051 genes/proteins)
- 849,456 edges (3 relation types)
- Train/Val/Test split: 70% / 15% / 15%
Reference: Chandak et al., "Building a knowledge graph to enable precision medicine." Scientific Data (2023).
- Python 3.8+
- CUDA 11.0+ (optional, for GPU acceleration)
- Clone the repository
git clone https://github.com/arnold117/PrimeKG-RGCN-LinkPrediction.git
cd PrimeKG-RGCN-LinkPrediction
- Create virtual environment
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
- Install dependencies
pip install -r requirements.txt
Key packages:
- `torch>=2.0.0` - PyTorch deep learning framework
- `torch-geometric>=2.3.0` - Graph neural networks
- `networkx>=2.8` - Graph algorithms
- `pandas>=2.0.0` - Data manipulation
- `matplotlib>=3.7.0` - Visualization
- `seaborn>=0.12.0` - Statistical visualization
- `scikit-learn>=1.3.0` - Machine learning utilities
- `tqdm>=4.65.0` - Progress bars
Optional:
- `plotly>=5.14.0` - Interactive visualizations
- `umap-learn>=0.5.3` - UMAP dimensionality reduction
Previous versions had a data leakage issue where 71.5% of test edges' reverse directions existed in training data. This has been fixed with undirected-edge-aware splitting, ensuring 0% data leakage.
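The fix amounts to canonicalizing each undirected edge before splitting, so that a pair and its reverse can never land in different splits. A minimal sketch (not the project's actual `preprocess.py`, which also handles relation types):

```python
import random

def leakage_free_split(edges, train_frac=0.70, val_frac=0.15, seed=42):
    """Split undirected edges so that (u, v) and (v, u) always share a split."""
    # Canonicalize: (u, v) and (v, u) collapse to one sorted tuple.
    canonical = sorted({tuple(sorted(e)) for e in edges})
    random.Random(seed).shuffle(canonical)
    n = len(canonical)
    n_train = int(n * train_frac)
    n_val = int(n * val_frac)
    train = canonical[:n_train]
    val = canonical[n_train:n_train + n_val]
    test = canonical[n_train + n_val:]
    return train, val, test

edges = [(0, 1), (1, 0), (2, 3), (3, 4), (4, 3), (5, 6)]
train, val, test = leakage_free_split(edges)
# A test edge's reverse can never appear in train: both directions share a split.
assert not {tuple(sorted(e)) for e in test} & {tuple(sorted(e)) for e in train}
```

Splitting on directed edges without this canonicalization is exactly what produced the 71.5% reverse-edge leakage in earlier versions.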
Strict Evaluation (hard negative sampling, 50 negatives per positive):
| Rank | Model | AUC-ROC | AP | Hits@10 | Hits@50 | MRR |
|---|---|---|---|---|---|---|
| 1 | GAT | 0.9866 | 0.7955 | 0.9831 | 1.0000 | 0.9031 |
| 2 | RGCN | 0.9794 | 0.6697 | 0.9789 | 1.0000 | 0.8699 |
| 3 | GIN | 0.9774 | 0.6522 | 0.9757 | 1.0000 | 0.8613 |
| 4 | GraphSAGE | 0.9742 | 0.6394 | 0.9789 | 1.0000 | 0.8862 |
| 5 | GCN | 0.9657 | 0.5538 | 0.9724 | 1.0000 | 0.8640 |
| 6 | MLP | 0.6592 | 0.0314 | 0.5234 | 0.9874 | 0.2710 |
Key Findings:
- Graph structure is critical: All GNNs (AUC-ROC > 0.96) vastly outperform MLP baseline (0.66)
- Attention mechanism wins: GAT achieves best AP (0.7955), 18.8% higher than second-place RGCN
- Relation types help moderately: RGCN > GCN (0.9794 vs 0.9657), but less impactful than attention
- Theoretical expressiveness ≠ task performance: GIN (WL-test equivalent) ranks 3rd, not 1st
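The hard-negative protocol behind these numbers can be sketched as follows: for each true pair, sample 50 corrupted pairs that are not known positives, then rank the true pair among them. This simplified version corrupts only the disease side and uses made-up IDs; `src/strict_evaluation.py` may differ in details.

```python
import random

def sample_hard_negatives(drug, true_disease, all_diseases, known_pairs,
                          num_neg=50, seed=0):
    """For one positive (drug, disease), sample negative diseases for the same
    drug, excluding every disease already known to be indicated for it."""
    rng = random.Random(seed)
    candidates = [d for d in all_diseases
                  if d != true_disease and (drug, d) not in known_pairs]
    return rng.sample(candidates, min(num_neg, len(candidates)))

all_diseases = list(range(100))
known_pairs = {(7, 3), (7, 12)}   # drug 7 already indicated for diseases 3, 12
negs = sample_hard_negatives(7, 3, all_diseases, known_pairs)
```

Filtering out known positives is what makes the protocol "strict": a model is never penalized for ranking another true indication highly.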
Data Split:
- Train: 838,882 edges (35,910 drug-gene + 802,972 other)
- Val: 7,688 edges
- Test: 7,708 edges
- Data Leakage: 0% ✓
| Parameter | Value |
|---|---|
| Epochs | 50 |
| Batch Size | 2048 |
| Learning Rate | 0.001 |
| Hidden Dim | 128 |
| Embedding Dim | 64 |
| Dropout | 0.5 |
| Negative Samples | 1 |
| Decoder | DistMult |
| Device | CUDA / MPS / CPU (auto-detect) |
PrimeKG-RGCN-LinkPrediction/
├── data/
│ ├── processed/ # Preprocessed graph data
│ │ ├── full_graph.pt # Complete knowledge graph
│ │ ├── train_data.pt # Training edges (70%)
│ │ ├── val_data.pt # Validation edges (15%)
│ │ ├── test_data.pt # Test edges (15%)
│ │ ├── mappings.pt # Node/relation mappings
│ │ └── statistics.csv # Dataset statistics
│ └── raw/ # Original PrimeKG data (download separately)
│
├── src/
│ ├── run_full_analysis.py # Main entry point for all analyses
│ ├── train.py # Multi-model training (--model flag)
│ ├── strict_evaluation.py # Strict eval with hard negatives
│ ├── evaluate.py # Basic evaluation metrics
│ ├── analyze_results.py # Advanced result analysis
│ ├── error_analysis.py # Error pattern analysis
│ ├── case_studies.py # Disease-specific predictions
│ ├── visualize_embeddings.py # Embedding visualization
│ ├── explain_predictions.py # Path-based explanations
│ ├── medical_validation.py # Biological validation
│ ├── compare_methods.py # Baseline comparison
│ ├── analyze_failures.py # Failure mode analysis
│ ├── data/
│ │ └── preprocess.py # Data preprocessing
│ └── models/ # Modular encoder registry
│ ├── __init__.py # Model registry & get_encoder()
│ ├── rgcn.py # Relational GCN (uses edge types)
│ ├── gcn.py # Standard GCN
│ ├── gat.py # Graph Attention Network
│ ├── graphsage.py # GraphSAGE
│ ├── gin.py # Graph Isomorphism Network
│ └── mlp.py # MLP baseline (no graph)
│
├── output/ # Training outputs (per-model subdirectories)
│ ├── rgcn/ # RGCN outputs
│ │ ├── models/ # best_model.pt, final_model.pt
│ │ ├── checkpoints/ # Periodic checkpoints
│ │ └── analysis/ # Strict evaluation results
│ ├── gat/ # GAT outputs (same structure)
│ ├── gcn/ # GCN outputs
│ ├── graphsage/ # GraphSAGE outputs
│ ├── gin/ # GIN outputs
│ └── mlp/ # MLP outputs
│
├── results/ # Evaluation results (Best model)
│ ├── results.json # Basic metrics
│ ├── analysis/ # Advanced analysis
│ ├── case_studies/ # Disease-specific studies
│ ├── embeddings/ # Embedding visualizations
│ ├── error_analysis/ # Error patterns
│ ├── explanations/ # Prediction explanations
│ ├── validation/ # Medical validation
│ ├── comparison/ # Method comparison
│ └── failure_analysis/ # Failure mode analysis
│
├── results_final/ # Final results (Final model)
│ └── [same structure as results/]
│
├── checkpoints/ # Legacy checkpoints (if any)
├── models/ # Legacy models (if any)
│
├── requirements.txt # Python dependencies
├── README.md # This file
├── guide/ # Collection of script guides
│ ├── README.md
│ ├── TRAINING_GUIDE.md
│ ├── PREPROCESS_GUIDE.md
│ ├── EVALUATION_GUIDE.md
│ ├── CASE_STUDIES_GUIDE.md
│ ├── MEDICAL_VALIDATION_GUIDE.md
│ ├── METHOD_COMPARISON_GUIDE.md
│ ├── RUN_FULL_ANALYSIS_GUIDE.md
│ ├── EXPLAIN_PREDICTIONS_GUIDE.md
│ ├── ANALYZE_FAILURES_GUIDE.md
│ ├── ERROR_ANALYSIS_GUIDE.md
│ ├── VISUALIZE_EMBEDDINGS_GUIDE.md
│ └── MODEL_ARCHITECTURE.md
# Train RGCN (default)
python src/train.py --model rgcn --epochs 50 --output_dir output/rgcn
# Train GAT (best performing)
python src/train.py --model gat --epochs 50 --output_dir output/gat
# Available models: rgcn, gcn, gat, graphsage, gin, mlp
python src/train.py --model <model_name> --epochs 50 --output_dir output/<model_name>
# Run on server (continues after terminal disconnect)
nohup python src/train.py --model gat --epochs 50 --output_dir output/gat > output/gat/train.log 2>&1 &
Detailed guide: guide/TRAINING_GUIDE.md
Use the main entry point to run comprehensive analysis:
# Run all analyses on best model
python src/run_full_analysis.py \
--model_path output/models/best_model.pt \
--output_dir results
# Run all analyses on final model (for paper)
python src/run_full_analysis.py \
--model_path output/models/final_model.pt \
--output_dir results_final
# Run specific analyses only
python src/run_full_analysis.py \
--analyses evaluate case_studies \
--model_path output/models/best_model.pt
# See all options
python src/run_full_analysis.py --help
Available Analyses:
- evaluate - Basic metrics (AUC-ROC, Hits@K, MRR)
- errors - Error pattern analysis
- case_studies - Disease-specific predictions
- embeddings - Embedding visualization
- explanations - Path-based explanations
- validation - Medical validation
- comparison - Baseline comparison
- failures - Failure mode analysis
# Run complete final analysis (saves to results_final/)
./run_final_analysis.sh
# Customize in the script:
# - Model path
# - Output directory
# - Specific analyses to run
# Train any model
python src/train.py --model <model_name> --epochs 50 --output_dir output/<model_name>
# With memory optimization
python src/train.py --model rgcn --batch_size 256 --gradient_accumulation_steps 4
Training outputs (per model in output/<model_name>/):
- `models/best_model.pt`: Best validation performance
- `models/final_model.pt`: Last epoch
- `checkpoints/`: Periodic checkpoints
Detailed guide: guide/TRAINING_GUIDE.md
# Evaluate with hard negative sampling (50 negatives per positive)
python src/strict_evaluation.py \
--model gat \
--checkpoint output/gat/models/best_model.pt \
--output_dir output/gat/analysis
# Custom number of negatives
python src/strict_evaluation.py --model rgcn --num_neg 100
# Basic evaluation
python src/evaluate.py --model_path output/rgcn/models/best_model.pt
# Custom settings
python src/evaluate.py \
--model_path output/rgcn/models/best_model.pt \
--output_dir results \
--batch_size 512 \
--k_values 10 50 100
Output files:
- `results.json`: All metrics in JSON format
- `metrics_summary.txt`: Human-readable summary
- `confusion_matrix.png`: Confusion matrix heatmap
- `roc_curve.png`: ROC curve visualization
- `precision_recall_curve.png`: PR curve
- `score_distribution.png`: Score distributions
Metrics computed:
- Classification: AUC-ROC, AUC-PR, Precision, Recall, F1
- Ranking: Hits@K, MRR, Mean/Median Rank
Detailed guide: guide/EVALUATION_GUIDE.md
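The ranking metrics reduce to one quantity: the rank of each positive among its sampled negatives. A compact NumPy sketch (array shapes and the toy data are illustrative):

```python
import numpy as np

def ranking_metrics(pos_scores, neg_scores, ks=(10, 50)):
    """pos_scores: (N,) model score of each true pair.
    neg_scores: (N, M) scores of the M negatives sampled per pair."""
    pos = np.asarray(pos_scores)[:, None]
    neg = np.asarray(neg_scores)
    # Rank of the positive = 1 + number of negatives scoring above it.
    ranks = 1 + (neg > pos).sum(axis=1)
    metrics = {"MRR": float((1.0 / ranks).mean())}
    for k in ks:
        metrics[f"Hits@{k}"] = float((ranks <= k).mean())
    return metrics

rng = np.random.default_rng(1)
pos = np.full(100, 0.9)                      # strong positives
neg = rng.uniform(0.0, 1.0, size=(100, 50))  # 50 negatives per positive
m = ranking_metrics(pos, neg)
```

Hits@K is the fraction of positives ranked in the top K; MRR averages the reciprocal ranks, so only a rank-1 positive contributes a full 1.0.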
# Comprehensive result analysis
python src/analyze_results.py --results_path results/results.json
# Compare multiple runs
python src/analyze_results.py \
--results_paths results/run1.json results/run2.json \
--labels "Run 1" "Run 2"
Analysis includes:
- Performance breakdowns by node type
- Score distributions and calibration
- Confidence intervals
- Statistical comparisons
Note: This script may need to be adapted for your specific analysis needs.
# Analyze error patterns
python src/error_analysis.py --model_path output/models/best_model.pt
# Focus on specific error types
python src/error_analysis.py \
--model_path output/models/best_model.pt \
--threshold 0.8 \
--output_dir results/error_analysis
Identifies:
- False positive patterns
- False negative patterns
- Confidence-error relationships
- Graph structure issues
Detailed guide: guide/ERROR_ANALYSIS_GUIDE.md
# Analyze top drug predictions for a disease
python src/case_studies.py --disease "Type 2 Diabetes" --top_k 10
# With confidence threshold
python src/case_studies.py \
--disease "Alzheimer disease" \
--top_k 20 \
--threshold 0.7
# GPU-accelerated for faster inference
python src/case_studies.py \
--disease "cancer" \
--model_path output/models/final_model.pt \
--output_dir results_final/case_studies
Generates:
- Case study report: Top predictions with biological insights
- Prediction scores plot: Bar chart of confidence scores
- Network diagrams: Drug-gene-disease pathways
- JSON export: Machine-readable results
Features:
- Known vs novel predictions
- Connection paths through genes
- Mechanistic interpretations
- Medical recommendations
Detailed guide: guide/CASE_STUDIES_GUIDE.md
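The known-vs-novel distinction can be sketched as a simple filter over the ranked predictions: anything already present as a training edge is "known", everything else is a repurposing candidate. The helper below is hypothetical (the real script works from model scores and node IDs):

```python
def label_predictions(ranked_drugs, disease, train_pairs):
    """Tag each ranked (drug, score) as 'known' if the drug-disease pair is
    already a training edge, else 'novel' (a repurposing candidate)."""
    return [(drug, score,
             "known" if (drug, disease) in train_pairs else "novel")
            for drug, score in ranked_drugs]

# Toy example with made-up names and scores.
train_pairs = {("metformin", "type 2 diabetes")}
ranked = [("metformin", 0.97), ("aspirin", 0.81)]
out = label_predictions(ranked, "type 2 diabetes", train_pairs)
```

Novel high-scoring predictions are the ones that then go through pathway analysis and medical validation.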
# Visualize learned embeddings
python src/visualize_embeddings.py --model_path output/models/best_model.pt
# Sample fewer nodes for faster visualization
python src/visualize_embeddings.py \
--model_path output/models/best_model.pt \
--sample_size 5000 \
--method tsne
# With clustering analysis
python src/visualize_embeddings.py \
--model_path output/models/best_model.pt \
--cluster \
--n_clusters 8
Visualizations:
- t-SNE/UMAP projections colored by node type
- Clustering analysis with silhouette scores
- Distance matrices (drug-drug, disease-disease, drug-disease)
- Nearest neighbor analysis
- Interactive HTML plots (optional)
Detailed guide: guide/VISUALIZE_EMBEDDINGS_GUIDE.md
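The projection step is standard scikit-learn t-SNE over the node-embedding matrix. A minimal sketch with synthetic stand-in embeddings (three fake node types, sizes chosen for illustration); note that `perplexity` must stay below the number of samples:

```python
import numpy as np
from sklearn.manifold import TSNE

# Toy stand-in for learned embeddings: 3 node types x 20 nodes x 64 dims.
rng = np.random.default_rng(0)
emb = np.vstack([rng.normal(loc=c, size=(20, 64)) for c in (0.0, 2.0, 4.0)])
node_type = ["drug"] * 20 + ["disease"] * 20 + ["gene"] * 20

# Project to 2-D for plotting; color points by node_type afterwards.
coords = TSNE(n_components=2, perplexity=10, random_state=0).fit_transform(emb)
```

The `--sample_size` flag exists because t-SNE scales poorly; subsampling a few thousand nodes keeps the projection tractable on the full 30,926-node graph.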
# Explain a specific prediction
python src/explain_predictions.py \
--drug "Metformin" \
--disease "diabetes mellitus" \
--top_k 5
# Batch explanation for multiple pairs
python src/explain_predictions.py \
--drug "Aspirin" \
--disease "heart disease" \
--top_k 10
Generates:
- Path-based explanations: Drug → Gene → Gene → Disease
- Natural language summaries: Human-readable explanations
- Network visualizations: Graph showing paths and importance
- Path ranking: Top-K most important paths
- Sankey diagrams: Flow visualization (optional)
Example output:
"Metformin may treat diabetes mellitus through a pathway
involving PRKAB1, PRKAA2, and RFX6. This connection suggests
a 4-step mechanism linking the drug's molecular targets to
the disease pathology."
Detailed guide: guide/EXPLAIN_PREDICTIONS_GUIDE.md
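At its core, a path-based explanation is an enumeration of short simple paths from the drug to the disease through gene nodes. A self-contained sketch using a plain adjacency dict (the tiny graph below just mirrors the Metformin example above; the real script works on the full PrimeKG graph):

```python
def find_paths(graph, start, goal, max_hops=3):
    """Enumerate simple paths from `start` to `goal` with at most
    `max_hops` edges. `graph` maps node -> set of neighbor nodes."""
    paths, stack = [], [(start, [start])]
    while stack:
        node, path = stack.pop()
        if node == goal:
            paths.append(path)
            continue
        if len(path) > max_hops:   # path already has max_hops edges
            continue
        for nxt in graph.get(node, ()):
            if nxt not in path:    # keep paths simple (no revisits)
                stack.append((nxt, path + [nxt]))
    return paths

graph = {
    "Metformin": {"PRKAB1"},
    "PRKAB1": {"PRKAA2"},
    "PRKAA2": {"RFX6"},
    "RFX6": {"diabetes mellitus"},
}
paths = find_paths(graph, "Metformin", "diabetes mellitus", max_hops=4)
```

Each recovered path can then be ranked (e.g. by the embedding scores of its edges) and templated into a natural-language summary like the example above.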
# Validate top novel predictions
python src/medical_validation.py --top_k 50
# Custom threshold and sampling
python src/medical_validation.py \
--top_k 100 \
--threshold 0.7 \
--sample_diseases 100
Validation criteria:
- Drug targets disease-related genes
- Common gene neighbors exist
- Literature evidence found (mock)
- Clinical trials exist (mock)
- Multiple connecting pathways
Outputs:
- Validation report: High/medium/low confidence predictions
- CSV export: Detailed scores and checklists
- Validation scores: Weighted assessment (0-1)
Detailed guide: guide/MEDICAL_VALIDATION_GUIDE.md
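The weighted 0-1 assessment can be sketched as a weighted average over the boolean criteria above. The specific weights below are invented for illustration; the real script's weighting may differ.

```python
def validation_score(checks, weights):
    """Weighted 0-1 plausibility score over boolean validation criteria."""
    total = sum(weights.values())
    return sum(weights[name] for name, passed in checks.items() if passed) / total

# Hypothetical weights over the five criteria listed above.
weights = {"targets_disease_genes": 0.3, "common_neighbors": 0.2,
           "literature": 0.2, "trials": 0.2, "multiple_paths": 0.1}
checks = {"targets_disease_genes": True, "common_neighbors": True,
          "literature": False, "trials": False, "multiple_paths": True}
score = validation_score(checks, weights)   # 0.3 + 0.2 + 0.1 = 0.6
```

Thresholds on this score then bucket predictions into the high/medium/low confidence tiers of the validation report.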
# Compare with baselines
python src/compare_methods.py --methods random degree rgcn
# Include TransE baseline
python src/compare_methods.py \
--methods random degree transe rgcn \
--transe_epochs 50
# Full analysis with all plots
python src/compare_methods.py \
--methods random degree rgcn \
--frequency_analysis \
--statistical_tests
Baselines:
- Random: Random predictions (lower bound)
- Node Degree: Popularity-based predictions
- TransE: Translation-based embeddings
- RGCN: Your model
Outputs:
- Comparison bar charts for all metrics
- Performance by disease frequency
- Statistical significance heatmap
- LaTeX/Markdown tables for papers
Detailed guide: guide/METHOD_COMPARISON_GUIDE.md
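The node-degree baseline requires no learning at all: it scores a candidate pair by how well connected its endpoints are in the training graph. A minimal sketch (toy edges; one common variant multiplies the two degrees, as done here):

```python
from collections import Counter

def degree_baseline_scores(train_edges, candidate_pairs):
    """Popularity baseline: score a (drug, disease) pair by the product of
    the two nodes' degrees in the training graph."""
    degree = Counter()
    for u, v in train_edges:
        degree[u] += 1
        degree[v] += 1
    return {pair: degree[pair[0]] * degree[pair[1]] for pair in candidate_pairs}

train_edges = [("drugA", "g1"), ("drugA", "g2"), ("g1", "disX"), ("g2", "disX")]
scores = degree_baseline_scores(
    train_edges, [("drugA", "disX"), ("drugA", "disY")])
```

Beating this baseline is the minimum bar for showing the model learned something beyond node popularity, which is also why the frequency analysis checks performance on rare diseases.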
# Deep dive into prediction failures
python src/analyze_failures.py --num_failures 5 --num_successes 5
# With subgraph visualizations
python src/analyze_failures.py \
--num_failures 10 \
--num_successes 10 \
--visualize_subgraphs \
--num_samples 10000
Analysis:
- Identifies worst predictions (high confidence but wrong)
- Compares with correct predictions
- Visualizes subgraphs around failures
- Generates failure hypotheses
- Suggests model improvements
Example findings:
- "Model fails when there are FEW CONNECTING PATHS (0.4 vs 7.6)"
- "Model makes more FALSE POSITIVES due to high-degree nodes"
- "Failures occur in SPARSE NEIGHBORHOODS"
Suggestions:
- Add attention mechanisms
- Increase GCN layers
- Add negative sampling
- Use degree normalization
Detailed guide: guide/ANALYZE_FAILURES_GUIDE.md
Each analysis creates a subdirectory in results_final/:
results_final/
├── evaluation/ # Basic metrics (AUC, Hits@K, MRR)
├── analysis/ # Performance by disease frequency, relation type
├── error_analysis/ # False positive/negative patterns
├── case_studies/ # Top predictions for specific diseases
├── embeddings/ # t-SNE/UMAP visualizations, clusters
├── explanations/ # Path-based reasoning for predictions
├── validation/ # Biological plausibility scores
├── comparison/ # Baseline method comparisons
└── failures/ # Failure case deep-dives
| File | Purpose | When to Use |
|---|---|---|
| `evaluation_results.csv` | Basic metrics overview | First check of model performance |
| `performance_by_disease_frequency.png` | Bias analysis | Check if model favors common diseases |
| `error_patterns.txt` | Common failure modes | Understand systematic errors |
| `alzheimers_predictions.csv` | Disease-specific predictions | Validate domain knowledge |
| `embedding_clusters_report.txt` | Entity clustering | Discover entity groupings |
| `explanation_summary.txt` | Top explanations | Understand model reasoning |
| `validation_report.txt` | Biological validation | Assess medical plausibility |
| `paper_table_latex.txt` | Method comparison table | Include in publications |
| `failure_analysis_report.txt` | Error hypotheses | Guide model improvements |
- Quick Assessment: Run evaluation + case studies
  python src/run_full_analysis.py \
    --model_path output/models/final_model.pt \
    --output_dir results_final \
    --analyses evaluate case_studies
- Deep Dive: Add error analysis + explanations for specific insights
- Publication: Run full analysis suite for comprehensive reporting
- Model Improvement: Focus on failures + validation to guide next iteration
Contributions are welcome! Here are some areas for improvement:
- Real Medical Validation
  - Integrate with PubMed API for literature validation
  - Connect to ClinicalTrials.gov for trial data
  - Add DrugBank integration
- Advanced Baselines
  - Implement ComplEx, RotatE embeddings
  - Include rule-based methods
- Interpretability
  - GAT attention weight visualization
  - GNNExplainer integration
  - Counterfactual explanations
- Scalability
  - Distributed training support
  - Mini-batch sampling for large graphs
  - Model compression techniques
- Fork the repository
- Create a feature branch (git checkout -b feature/AmazingFeature)
- Make your changes
- Add tests if applicable
- Commit your changes (git commit -m 'Add AmazingFeature')
- Push to the branch (git push origin feature/AmazingFeature)
- Open a Pull Request
- Follow PEP 8 guidelines
- Add docstrings to all functions
- Include type hints where possible
- Write descriptive commit messages
If you use this code in your research, please cite:
@software{primekg_rgcn_2025,
author = {arnold117},
title = {PrimeKG-RGCN-LinkPrediction: Drug-Disease Link Prediction with Relational Graph Convolutional Networks},
year = {2025},
url = {https://github.com/arnold117/PrimeKG-RGCN-LinkPrediction}
}
And the PrimeKG dataset:
@article{chandak2023building,
title={Building a knowledge graph to enable precision medicine},
author={Chandak, Payal and Huang, Kexin and Zitnik, Marinka},
journal={Scientific Data},
volume={10},
number={1},
pages={67},
year={2023},
publisher={Nature Publishing Group}
}
This project is licensed under the MIT License - see the LICENSE file for details.
- PrimeKG Team at Harvard Medical School for the knowledge graph
- PyTorch Geometric team for the GNN framework
- Open-source community for various analysis tools
For questions or issues:
- Open an issue on GitHub
Q: How long does training take?
A: ~4-5 hours on a single GPU (GTX 1070) for 100 epochs.
Q: Can I use this on my own knowledge graph?
A: Yes! Modify src/data/preprocess.py to load your data format.
Q: What GPU memory is required?
A: Minimum 2GB. Our model uses less than 1GB during training.
Q: How do I handle OOM errors?
A: Reduce batch size or hidden dimensions in src/train.py.
Q: Can I add more baseline methods?
A: Yes! Extend the BaselineMethod class in src/compare_methods.py.
Q: How accurate are the medical validations?
A: Current implementation uses proxy signals. For production use, integrate real biomedical databases.