Skip to content

arnold117/PrimeKG-RGCN-LinkPrediction

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

27 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

PrimeKG-RGCN-LinkPrediction

Multi-architecture Graph Neural Network comparison for drug-disease link prediction using the PrimeKG knowledge graph. Compares 6 GNN architectures (RGCN, GCN, GAT, GraphSAGE, GIN) and an MLP baseline with strict evaluation under hard negative sampling.

Overview

This project implements a complete pipeline for biomedical link prediction, from data preprocessing to advanced analysis. Multiple GNN encoders learn from drug-gene and gene-disease relationships to predict potential therapeutic indications, with strict evaluation ensuring no data leakage and hard negative sampling for reliable metrics.

Key Features

  • Multi-model comparison: 6 architectures (RGCN, GCN, GAT, GraphSAGE, GIN, MLP) with unified training and evaluation
  • Strict evaluation: Hard negative sampling (50 negatives/positive) with data-leakage-free splitting
  • Modular design: Plug-and-play encoder registry — easy to add new GNN architectures
  • Comprehensive evaluation: classification metrics, ranking metrics, and error analysis
  • Medical validation: biological plausibility checking and evidence gathering
  • Drug repurposing: disease-specific case studies with pathway analysis
  • Interpretable predictions: path-based explanations with natural language generation
  • Embedding analysis: t-SNE/UMAP visualization and clustering
  • GPU accelerated: optimized for fast inference and batch processing (CUDA / MPS / CPU)

Methodology

Drug --[interacts]--> Gene --[associated]--> Disease

The model learns to predict drug-disease indications by:

  1. Encoding drug-gene and gene-disease relationships
  2. Learning node embeddings via relational graph convolutions
  3. Predicting links between drug and disease nodes

Dataset

This project uses PrimeKG, a precision medicine knowledge graph containing:

  • 4.5 million relationships
  • 20 different data sources
  • 129,375 nodes (drugs, diseases, genes, proteins, etc.)
  • Multiple relation types including:
    • Drug-Gene interactions
    • Gene-Gene interactions
    • Gene-Disease associations

Our processed graph:

  • 30,926 nodes (6,282 drugs, 5,593 diseases, 19,051 genes/proteins)
  • 849,456 edges (3 relation types)
  • Train/Val/Test split: 70% / 15% / 15%

Reference: Chandak et al., "Building a knowledge graph to enable precision medicine." Nature Scientific Data (2023).

Installation

Prerequisites

  • Python 3.8+
  • CUDA 11.0+ (optional, for GPU acceleration)

Setup

  1. Clone the repository
git clone https://github.com/arnold117/PrimeKG-RGCN-LinkPrediction.git
cd PrimeKG-RGCN-LinkPrediction
  1. Create virtual environment
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
  1. Install dependencies
pip install -r requirements.txt

Dependencies

Key packages:

  • torch>=2.0.0 - PyTorch deep learning framework
  • torch-geometric>=2.3.0 - Graph neural networks
  • networkx>=2.8 - Graph algorithms
  • pandas>=2.0.0 - Data manipulation
  • matplotlib>=3.7.0 - Visualization
  • seaborn>=0.12.0 - Statistical visualization
  • scikit-learn>=1.3.0 - Machine learning utilities
  • tqdm>=4.65.0 - Progress bars

Optional:

  • plotly>=5.14.0 - Interactive visualizations
  • umap-learn>=0.5.3 - UMAP dimensionality reduction

Model Performance

Data Leakage Fix (v2.0)

Previous versions had a data leakage issue where 71.5% of test edges' reverse directions existed in training data. This has been fixed with undirected-edge-aware splitting, ensuring 0% data leakage.

Multi-Model Comparison

Strict Evaluation (hard negative sampling, 50 negatives per positive):

Rank Model AUC-ROC AP Hits@10 Hits@50 MRR
1 GAT 0.9866 0.7955 0.9831 1.0000 0.9031
2 RGCN 0.9794 0.6697 0.9789 1.0000 0.8699
3 GIN 0.9774 0.6522 0.9757 1.0000 0.8613
4 GraphSAGE 0.9742 0.6394 0.9789 1.0000 0.8862
5 GCN 0.9657 0.5538 0.9724 1.0000 0.8640
6 MLP 0.6592 0.0314 0.5234 0.9874 0.2710

Key Findings:

  • Graph structure is critical: All GNNs (AUC-ROC > 0.96) vastly outperform MLP baseline (0.66)
  • Attention mechanism wins: GAT achieves best AP (0.7955), 18.8% higher than second-place RGCN
  • Relation types help moderately: RGCN > GCN (0.9794 vs 0.9657), but less impactful than attention
  • Theoretical expressiveness ≠ task performance: GIN (WL-test equivalent) ranks 3rd, not 1st

Data Split:

  • Train: 838,882 edges (35,910 drug-gene + 802,972 other)
  • Val: 7,688 edges
  • Test: 7,708 edges
  • Data Leakage: 0%

Training Configuration

Parameter Value
Epochs 50
Batch Size 2048
Learning Rate 0.001
Hidden Dim 128
Embedding Dim 64
Dropout 0.5
Negative Samples 1
Decoder DistMult
Device CUDA / MPS / CPU (auto-detect)

Project Structure

PrimeKG-RGCN-LinkPrediction/
├── data/
│   ├── processed/          # Preprocessed graph data
│   │   ├── full_graph.pt       # Complete knowledge graph
│   │   ├── train_data.pt       # Training edges (70%)
│   │   ├── val_data.pt         # Validation edges (15%)
│   │   ├── test_data.pt        # Test edges (15%)
│   │   ├── mappings.pt         # Node/relation mappings
│   │   └── statistics.csv      # Dataset statistics
│   └── raw/               # Original PrimeKG data (download separately)
│
├── src/
│   ├── run_full_analysis.py    # Main entry point for all analyses
│   ├── train.py                # Multi-model training (--model flag)
│   ├── strict_evaluation.py    # Strict eval with hard negatives
│   ├── evaluate.py             # Basic evaluation metrics
│   ├── analyze_results.py      # Advanced result analysis
│   ├── error_analysis.py       # Error pattern analysis
│   ├── case_studies.py         # Disease-specific predictions
│   ├── visualize_embeddings.py # Embedding visualization
│   ├── explain_predictions.py  # Path-based explanations
│   ├── medical_validation.py   # Biological validation
│   ├── compare_methods.py      # Baseline comparison
│   ├── analyze_failures.py     # Failure mode analysis
│   ├── data/
│   │   └── preprocess.py       # Data preprocessing
│   └── models/                 # Modular encoder registry
│       ├── __init__.py         # Model registry & get_encoder()
│       ├── rgcn.py             # Relational GCN (uses edge types)
│       ├── gcn.py              # Standard GCN
│       ├── gat.py              # Graph Attention Network
│       ├── graphsage.py        # GraphSAGE
│       ├── gin.py              # Graph Isomorphism Network
│       └── mlp.py              # MLP baseline (no graph)
│
├── output/                # Training outputs (per-model subdirectories)
│   ├── rgcn/              # RGCN outputs
│   │   ├── models/            # best_model.pt, final_model.pt
│   │   ├── checkpoints/       # Periodic checkpoints
│   │   └── analysis/          # Strict evaluation results
│   ├── gat/               # GAT outputs (same structure)
│   ├── gcn/               # GCN outputs
│   ├── graphsage/         # GraphSAGE outputs
│   ├── gin/               # GIN outputs
│   └── mlp/               # MLP outputs
│
├── results/               # Evaluation results (Best model)
│   ├── results.json           # Basic metrics
│   ├── analysis/              # Advanced analysis
│   ├── case_studies/          # Disease-specific studies
│   ├── embeddings/            # Embedding visualizations
│   ├── error_analysis/        # Error patterns
│   ├── explanations/          # Prediction explanations
│   ├── validation/            # Medical validation
│   ├── comparison/            # Method comparison
│   └── failure_analysis/      # Failure mode analysis
│
├── results_final/         # Final results (Final model)
│   └── [same structure as results/]
│
├── checkpoints/          # Legacy checkpoints (if any)
├── models/               # Legacy models (if any)
│
├── requirements.txt      # Python dependencies
├── README.md            # This file
├── guide/                # Collection of script guides
│   ├── README.md
│   ├── TRAINING_GUIDE.md
│   ├── PREPROCESS_GUIDE.md
│   ├── EVALUATION_GUIDE.md
│   ├── CASE_STUDIES_GUIDE.md
│   ├── MEDICAL_VALIDATION_GUIDE.md
│   ├── METHOD_COMPARISON_GUIDE.md
│   ├── RUN_FULL_ANALYSIS_GUIDE.md
│   ├── EXPLAIN_PREDICTIONS_GUIDE.md
│   ├── ANALYZE_FAILURES_GUIDE.md
│   ├── ERROR_ANALYSIS_GUIDE.md
│   ├── VISUALIZE_EMBEDDINGS_GUIDE.md
│   └── MODEL_ARCHITECTURE.md

Quick Start

1. Training

# Train RGCN (default)
python src/train.py --model rgcn --epochs 50 --output_dir output/rgcn

# Train GAT (best performing)
python src/train.py --model gat --epochs 50 --output_dir output/gat

# Available models: rgcn, gcn, gat, graphsage, gin, mlp
python src/train.py --model <model_name> --epochs 50 --output_dir output/<model_name>

# Run on server (continues after terminal disconnect)
nohup python src/train.py --model gat --epochs 50 --output_dir output/gat > output/gat/train.log 2>&1 &

Detailed guide: guide/TRAINING_GUIDE.md

2. Run All Analyses

Use the main entry point to run comprehensive analysis:

# Run all analyses on best model
python src/run_full_analysis.py \
    --model_path output/models/best_model.pt \
    --output_dir results

# Run all analyses on final model (for paper)
python src/run_full_analysis.py \
    --model_path output/models/final_model.pt \
    --output_dir results_final

# Run specific analyses only
python src/run_full_analysis.py \
    --analyses evaluate case_studies \
    --model_path output/models/best_model.pt

# See all options
python src/run_full_analysis.py --help

Available Analyses:

  1. evaluate - Basic metrics (AUC-ROC, Hits@K, MRR)
  2. errors - Error pattern analysis
  3. case_studies - Disease-specific predictions
  4. embeddings - Embedding visualization
  5. explanations - Path-based explanations
  6. validation - Medical validation
  7. comparison - Baseline comparison
  8. failures - Failure mode analysis

3. Convenient Shell Script

# Run complete final analysis (saves to results_final/)
./run_final_analysis.sh

# Customize in the script:
# - Model path
# - Output directory
# - Specific analyses to run

Usage Guide

Training

# Train any model
python src/train.py --model <model_name> --epochs 50 --output_dir output/<model_name>

# With memory optimization
python src/train.py --model rgcn --batch_size 256 --gradient_accumulation_steps 4

Training outputs (per model in output/<model_name>/):

  • models/best_model.pt: Best validation performance
  • models/final_model.pt: Last epoch
  • checkpoints/: Periodic checkpoints

Detailed guide: guide/TRAINING_GUIDE.md

Strict Evaluation

# Evaluate with hard negative sampling (50 negatives per positive)
python src/strict_evaluation.py \
    --model gat \
    --checkpoint output/gat/models/best_model.pt \
    --output_dir output/gat/analysis

# Custom number of negatives
python src/strict_evaluation.py --model rgcn --num_neg 100

Basic Evaluation

# Basic evaluation
python src/evaluate.py --model_path output/rgcn/models/best_model.pt

# Custom settings
python src/evaluate.py \
    --model_path output/rgcn/models/best_model.pt \
    --output_dir results \
    --batch_size 512 \
    --k_values 10 50 100

Output files:

  • results.json: All metrics in JSON format
  • metrics_summary.txt: Human-readable summary
  • confusion_matrix.png: Confusion matrix heatmap
  • roc_curve.png: ROC curve visualization
  • precision_recall_curve.png: PR curve
  • score_distribution.png: Score distributions

Metrics computed:

  • Classification: AUC-ROC, AUC-PR, Precision, Recall, F1
  • Ranking: Hits@K, MRR, Mean/Median Rank

Detailed guide: guide/EVALUATION_GUIDE.md

Advanced Result Analysis

# Comprehensive result analysis
python src/analyze_results.py --results_path results/results.json

# Compare multiple runs
python src/analyze_results.py \
    --results_paths results/run1.json results/run2.json \
    --labels "Run 1" "Run 2"

Analysis includes:

  • Performance breakdowns by node type
  • Score distributions and calibration
  • Confidence intervals
  • Statistical comparisons

Note: This script may need to be adapted for your specific analysis needs.

Error Analysis

# Analyze error patterns
python src/error_analysis.py --model_path output/models/best_model.pt

# Focus on specific error types
python src/error_analysis.py \
    --model_path output/models/best_model.pt \
    --threshold 0.8 \
    --output_dir results/error_analysis

Identifies:

  • False positive patterns
  • False negative patterns
  • Confidence-error relationships
  • Graph structure issues

Detailed guide: guide/ERROR_ANALYSIS_GUIDE.md

Disease-Specific Case Studies

# Analyze top drug predictions for a disease
python src/case_studies.py --disease "Type 2 Diabetes" --top_k 10

# With confidence threshold
python src/case_studies.py \
    --disease "Alzheimer disease" \
    --top_k 20 \
    --threshold 0.7

# GPU-accelerated for faster inference
python src/case_studies.py \
    --disease "cancer" \
    --model_path output/models/final_model.pt \
    --output_dir results_final/case_studies

Generates:

  • Case study report: Top predictions with biological insights
  • Prediction scores plot: Bar chart of confidence scores
  • Network diagrams: Drug-gene-disease pathways
  • JSON export: Machine-readable results

Features:

  • Known vs novel predictions
  • Connection paths through genes
  • Mechanistic interpretations
  • Medical recommendations

Detailed guide: guide/CASE_STUDIES_GUIDE.md

Embedding Visualization

# Visualize learned embeddings
python src/visualize_embeddings.py --model_path output/models/best_model.pt

# Sample fewer nodes for faster visualization
python src/visualize_embeddings.py \
    --model_path output/models/best_model.pt \
    --sample_size 5000 \
    --method tsne

# With clustering analysis
python src/visualize_embeddings.py \
    --model_path output/models/best_model.pt \
    --cluster \
    --n_clusters 8

Visualizations:

  • t-SNE/UMAP projections colored by node type
  • Clustering analysis with silhouette scores
  • Distance matrices (drug-drug, disease-disease, drug-disease)
  • Nearest neighbor analysis
  • Interactive HTML plots (optional)

Detailed guide: guide/VISUALIZE_EMBEDDINGS_GUIDE.md

Explainable Predictions

# Explain a specific prediction
python src/explain_predictions.py \
    --drug "Metformin" \
    --disease "diabetes mellitus" \
    --top_k 5

# Batch explanation for multiple pairs
python src/explain_predictions.py \
    --drug "Aspirin" \
    --disease "heart disease" \
    --top_k 10

Generates:

  • Path-based explanations: Drug → Gene → Gene → Disease
  • Natural language summaries: Human-readable explanations
  • Network visualizations: Graph showing paths and importance
  • Path ranking: Top-K most important paths
  • Sankey diagrams: Flow visualization (optional)

Example output:

"Metformin may treat diabetes mellitus through a pathway 
involving PRKAB1, PRKAA2, and RFX6. This connection suggests 
a 4-step mechanism linking the drug's molecular targets to 
the disease pathology."

Detailed guide: guide/EXPLAIN_PREDICTIONS_GUIDE.md

Medical Validation

# Validate top novel predictions
python src/medical_validation.py --top_k 50

# Custom threshold and sampling
python src/medical_validation.py \
    --top_k 100 \
    --threshold 0.7 \
    --sample_diseases 100

Validation criteria:

  • Drug targets disease-related genes
  • Common gene neighbors exist
  • Literature evidence found (mock)
  • Clinical trials exist (mock)
  • Multiple connecting pathways

Outputs:

  • Validation report: High/medium/low confidence predictions
  • CSV export: Detailed scores and checklists
  • Validation scores: Weighted assessment (0-1)

Detailed guide: guide/MEDICAL_VALIDATION_GUIDE.md

Baseline Comparison

# Compare with baselines
python src/compare_methods.py --methods random degree rgcn

# Include TransE baseline
python src/compare_methods.py \
    --methods random degree transe rgcn \
    --transe_epochs 50

# Full analysis with all plots
python src/compare_methods.py \
    --methods random degree rgcn \
    --frequency_analysis \
    --statistical_tests

Baselines:

  • Random: Random predictions (lower bound)
  • Node Degree: Popularity-based predictions
  • TransE: Translation-based embeddings
  • RGCN: Your model

Outputs:

  • Comparison bar charts for all metrics
  • Performance by disease frequency
  • Statistical significance heatmap
  • LaTeX/Markdown tables for papers

Detailed guide: guide/METHOD_COMPARISON_GUIDE.md

Failure Analysis

# Deep dive into prediction failures
python src/analyze_failures.py --num_failures 5 --num_successes 5

# With subgraph visualizations
python src/analyze_failures.py \
    --num_failures 10 \
    --num_successes 10 \
    --visualize_subgraphs \
    --num_samples 10000

Analysis:

  • Identifies worst predictions (high confidence but wrong)
  • Compares with correct predictions
  • Visualizes subgraphs around failures
  • Generates failure hypotheses
  • Suggests model improvements

Example findings:

  • "Model fails when there are FEW CONNECTING PATHS (0.4 vs 7.6)"
  • "Model makes more FALSE POSITIVES due to high-degree nodes"
  • "Failures occur in SPARSE NEIGHBORHOODS"

Suggestions:

  • Add attention mechanisms
  • Increase GCN layers
  • Add negative sampling
  • Use degree normalization

Detailed guide: guide/ANALYZE_FAILURES_GUIDE.md

Interpreting Results

Understanding Output Directories

Each analysis creates a subdirectory in results_final/:

results_final/
├── evaluation/          # Basic metrics (AUC, Hits@K, MRR)
├── analysis/            # Performance by disease frequency, relation type
├── error_analysis/      # False positive/negative patterns
├── case_studies/        # Top predictions for specific diseases
├── embeddings/          # t-SNE/UMAP visualizations, clusters
├── explanations/        # Path-based reasoning for predictions
├── validation/          # Biological plausibility scores
├── comparison/          # Baseline method comparisons
└── failures/            # Failure case deep-dives

Key Output Files

File Purpose When to Use
evaluation_results.csv Basic metrics overview First check of model performance
performance_by_disease_frequency.png Bias analysis Check if model favors common diseases
error_patterns.txt Common failure modes Understand systematic errors
alzheimers_predictions.csv Disease-specific predictions Validate domain knowledge
embedding_clusters_report.txt Entity clustering Discover entity groupings
explanation_summary.txt Top explanations Understand model reasoning
validation_report.txt Biological validation Assess medical plausibility
paper_table_latex.txt Method comparison table Include in publications
failure_analysis_report.txt Error hypotheses Guide model improvements

Recommended Workflow

  1. Quick Assessment: Run evaluation + case studies

    python src/run_full_analysis.py \
      --model_path output/models/final_model.pt \
      --output_dir results_final \
      --analyses evaluate case_studies
  2. Deep Dive: Add error analysis + explanations for specific insights

  3. Publication: Run full analysis suite for comprehensive reporting

  4. Model Improvement: Focus on failures + validation to guide next iteration

Contributing

Contributions are welcome! Here are some areas for improvement:

High-Priority Enhancements

  1. Real Medical Validation

    • Integrate with PubMed API for literature validation
    • Connect to ClinicalTrials.gov for trial data
    • Add DrugBank integration
  2. Advanced Baselines

    • Implement ComplEx, RotatE embeddings
    • Include rule-based methods
  3. Interpretability

    • GAT attention weight visualization
    • GNNExplainer integration
    • Counterfactual explanations
  4. Scalability

    • Distributed training support
    • Mini-batch sampling for large graphs
    • Model compression techniques

How to Contribute

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/AmazingFeature)
  3. Make your changes
  4. Add tests if applicable
  5. Commit your changes (git commit -m 'Add AmazingFeature')
  6. Push to the branch (git push origin feature/AmazingFeature)
  7. Open a Pull Request

Code Style

  • Follow PEP 8 guidelines
  • Add docstrings to all functions
  • Include type hints where possible
  • Write descriptive commit messages

Citation

If you use this code in your research, please cite:

@software{primekg_rgcn_2025,
  author = {arnold117},
  title = {PrimeKG-RGCN-LinkPrediction: Drug-Disease Link Prediction with Relational Graph Convolutional Networks},
  year = {2025},
  url = {https://github.com/arnold117/PrimeKG-RGCN-LinkPrediction}
}

And the PrimeKG dataset:

@article{chandak2023building,
  title={Building a knowledge graph to enable precision medicine},
  author={Chandak, Payal and Huang, Kexin and Zitnik, Marinka},
  journal={Nature Scientific Data},
  volume={10},
  number={1},
  pages={67},
  year={2023},
  publisher={Nature Publishing Group}
}

License

This project is licensed under the MIT License - see the LICENSE file for details.

Acknowledgments

  • PrimeKG Team at Harvard Medical School for the knowledge graph
  • PyTorch Geometric team for the GNN framework
  • Open-source community for various analysis tools

Contact

For questions or issues:

  • Open an issue on GitHub

FAQ

Q: How long does training take?
A: ~4-5 hours on a single GPU (RTX 1070) for 100 epochs.

Q: Can I use this on my own knowledge graph?
A: Yes! Modify src/data/preprocess.py to load your data format.

Q: What GPU memory is required?
A: Minimum 2GB. Our model uses less than 1GB during training.

Q: How do I handle OOM errors?
A: Reduce batch size or hidden dimensions in src/train.py.

Q: Can I add more baseline methods?
A: Yes! Extend the BaselineMethod class in src/compare_methods.py.

Q: How accurate are the medical validations?
A: Current implementation uses proxy signals. For production use, integrate real biomedical databases.


About

Relational Graph Convolutional Network (RGCN) for predicting drug-disease indications using PrimeKG knowledge graph. Leverages drug-gene and gene-disease relationships for multi-hop reasoning in biomedical link prediction.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors