-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathinstall.sh
More file actions
executable file
·223 lines (181 loc) · 6.34 KB
/
install.sh
File metadata and controls
executable file
·223 lines (181 loc) · 6.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
#!/bin/bash
# Kascade Installation Script
# This script sets up all dependencies for Kascade
# Prerequisites: conda environment with Python 3.12 must be created and activated
set -e # Exit on error
# Configuration
THIRD_PARTY_DIR="third_party"
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
print_step() {
echo -e "${GREEN}==>${NC} $1"
}
print_warning() {
echo -e "${YELLOW}Warning:${NC} $1"
}
print_error() {
echo -e "${RED}Error:${NC} $1"
}
# Check prerequisites
check_prerequisites() {
print_step "Checking prerequisites..."
# Check if conda environment is activated
if [ -z "$CONDA_DEFAULT_ENV" ]; then
print_error "No conda environment detected. Please create and activate a conda environment first:"
echo " conda create -n kascade python=3.12"
echo " conda activate kascade"
exit 1
fi
# Check Python version
PYTHON_VERSION=$(python -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")
if [[ "$PYTHON_VERSION" != "3.12" ]]; then
print_warning "Python version is $PYTHON_VERSION, recommended is 3.12"
fi
# Check if CUDA is available
if ! command -v nvcc &> /dev/null; then
print_warning "nvcc not found. CUDA toolkit may not be installed."
fi
# Check if git is available
if ! command -v git &> /dev/null; then
print_error "git is not installed. Please install git first."
exit 1
fi
echo " Conda environment: $CONDA_DEFAULT_ENV"
echo " Python version: $PYTHON_VERSION"
}
# Detect GPU architecture
detect_gpu_arch() {
print_step "Detecting GPU architecture..."
if ! command -v nvidia-smi &> /dev/null; then
print_warning "nvidia-smi not found. Cannot detect GPU architecture."
echo "unknown"
return
fi
COMPUTE_CAP=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n 1 | cut -d'.' -f1)
echo "$COMPUTE_CAP"
}
# Install PyTorch with CUDA 12.8
install_pytorch() {
print_step "Installing PyTorch with CUDA 12.8..."
pip install torch==2.7.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
}
# Install CUDA via conda
install_cuda() {
print_step "Installing CUDA 12.8.1 via conda..."
conda install -y cuda -c nvidia/label/cuda-12.8.1
echo ">> Configuring CUDA environment..."
# Link the hidden headers to the standard location
SOURCE_DIR="$CONDA_PREFIX/targets/x86_64-linux/include"
DEST_DIR="$CONDA_PREFIX/include"
# Loop through every file in the source
for filepath in "$SOURCE_DIR"/*; do
filename=$(basename "$filepath")
# Only link if the destination does NOT exist yet
if [ ! -e "$DEST_DIR/$filename" ]; then
ln -s "$filepath" "$DEST_DIR/$filename"
fi
done
}
# Install build tools
install_build_tools() {
print_step "Installing build tools..."
pip install packaging ninja
}
# Initialize and update git submodules
init_submodules() {
print_step "Initializing git submodules..."
git submodule update --init --recursive
}
# Install TileLang
install_tilelang() {
print_step "Installing TileLang..."
# Clean build artifacts if they exist
if [ -d "$THIRD_PARTY_DIR/tilelang/build" ]; then
print_warning "Cleaning TileLang build artifacts..."
rm -rf "$THIRD_PARTY_DIR/tilelang/build"
rm -rf "$THIRD_PARTY_DIR/tilelang/tilelang.egg-info"
rm -rf "$THIRD_PARTY_DIR/tilelang/*.egg-info"
fi
# Install TileLang
cd "$THIRD_PARTY_DIR/tilelang"
pip install .
cd ../..
}
# Install Flash Attention 3
install_flash_attention_3() {
print_step "Installing Flash Attention 3..."
# Build Flash Attention 3 from hopper directory
cd "$THIRD_PARTY_DIR/flash-attention/hopper"
python setup.py install
cd ../../..
}
# Install Flash Attention 2
install_flash_attention_2() {
print_step "Installing Flash Attention 2..."
pip install flash-attn==2.8.0.post2 --no-build-isolation
}
# Install Kascade package
install_kascade() {
print_step "Installing Kascade package..."
pip install -e .
}
# Main installation flow
main() {
echo "============================================"
echo " Kascade Installation Script"
echo "============================================"
echo ""
check_prerequisites
# Detect GPU architecture
GPU_ARCH=$(detect_gpu_arch)
echo " Detected GPU architecture: $GPU_ARCH"
# Get submodule commits for display
TILELANG_COMMIT=$(cd "$THIRD_PARTY_DIR/tilelang" 2>/dev/null && git rev-parse HEAD 2>/dev/null || echo "not initialized")
FA_COMMIT=$(cd "$THIRD_PARTY_DIR/flash-attention" 2>/dev/null && git rev-parse HEAD 2>/dev/null || echo "not initialized")
echo ""
echo "This script will install:"
echo " 1. PyTorch with CUDA 12.8"
echo " 2. CUDA 12.8.1 (via conda)"
echo " 3. Build tools (packaging, ninja)"
echo " 4. TileLang (submodule: ${TILELANG_COMMIT:0:7})"
if [ "$GPU_ARCH" -ge 9 ] 2>/dev/null; then
echo " 5. Flash Attention 3 (submodule: ${FA_COMMIT:0:7}) [H100/H800 detected]"
else
echo " 5. Flash Attention 2 only (v2.8.0.post2) [Non-Hopper GPU]"
fi
echo " 6. Kascade and its dependencies"
echo ""
read -p "Continue with installation? [y/N] " -n 1 -r
echo ""
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
echo "Installation cancelled."
exit 0
fi
install_pytorch
install_cuda
install_build_tools
init_submodules
install_tilelang
if [ "$GPU_ARCH" -ge 9 ] 2>/dev/null; then
print_step "Hopper GPU detected - Installing Flash Attention 3..."
install_flash_attention_3
else
print_step "Non-Hopper GPU detected - Installing Flash Attention 2 only..."
print_warning "Flash Attention 3 requires H100/H800 GPUs (Hopper architecture)"
install_flash_attention_2
fi
install_kascade
echo ""
echo "============================================"
echo -e "${GREEN}Installation complete!${NC}"
echo "============================================"
echo ""
if [ "$GPU_ARCH" -lt 9 ] 2>/dev/null; then
echo -e "${YELLOW}Note: Not using hopper architecture. Efficient Kascade and FA-3 will not work${NC}"
echo ""
fi
}
main "$@"