Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 56 additions & 6 deletions modules/launch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import sys
import importlib.util
import importlib.metadata
import packaging.version
import platform
import json
import shlex
import tempfile
from functools import lru_cache

from modules import cmd_args, errors
Expand Down Expand Up @@ -277,7 +279,7 @@ def run_extensions_installers(settings_file):
startup_timer.record(dirname_extension)


re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:([<>=]=)\s*([-+_.a-zA-Z0-9]+))?\s*")


def requirements_met(requirements_file):
Expand All @@ -291,25 +293,31 @@ def requirements_met(requirements_file):

with open(requirements_file, "r", encoding="utf8") as file:
for line in file:
if line.strip() == "":
if line.strip() == "" or line[0] == '#':
continue

m = re.match(re_requirement, line)
if m is None:
return False

package = m.group(1).strip()
version_required = (m.group(2) or "").strip()
version_required = (m.group(3) or "").strip()

if version_required == "":
continue

condition = m.group(2)

try:
version_installed = importlib.metadata.version(package)
except Exception:
return False

if packaging.version.parse(version_required) != packaging.version.parse(version_installed):
if condition == '==' and packaging.version.parse(version_installed) != packaging.version.parse(version_required):
return False
elif condition == '>=' and packaging.version.parse(version_installed) < packaging.version.parse(version_required):
return False
elif condition == '<=' and packaging.version.parse(version_installed) > packaging.version.parse(version_required):
return False

return True
Expand Down Expand Up @@ -419,15 +427,57 @@ def prepare_environment():
if not os.path.isfile(requirements_file):
requirements_file = os.path.join(script_path, requirements_file)


def install_minimal_requirements(requirements_file, desc=None):
with open(requirements_file, "r", encoding="utf8") as req:
# revert >= condition to == for non installed package to install minimal required version
lines = req.readlines()

requirements = []
for line in lines:
if line.strip() == "" or line[0] == '#':
continue

m = re.match(re_requirement, line)
if m is None:
continue

package = m.group(1).strip()
version_required = (m.group(3) or "").strip()

if version_required == "":
requirements.append(line)
continue

condition = m.group(2)

try:
version_installed = importlib.metadata.version(package)

if condition == '>=' and packaging.version.parse(version_installed) < packaging.version.parse(version_required):
line = line.replace(">=", "==")
except Exception:
line = line.replace(">=", "==")
requirements.append(line)


temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w")

temp_file.writelines(requirements)
temp_file.close()
run_pip(f"install -r \"{temp_file.name}\"", desc)
os.remove(temp_file.name)


if not requirements_met(requirements_file):
run_pip(f"install -r \"{requirements_file}\"", "requirements")
install_minimal_requirements(requirements_file, "requirements")
startup_timer.record("install requirements")

if not os.path.isfile(requirements_file_for_npu):
requirements_file_for_npu = os.path.join(script_path, requirements_file_for_npu)

if "torch_npu" in torch_command and not requirements_met(requirements_file_for_npu):
run_pip(f"install -r \"{requirements_file_for_npu}\"", "requirements_for_npu")
install_minimal_requirements(requirements_file, "requirements_for_npu")
startup_timer.record("install requirements_for_npu")

if not args.skip_install:
Expand Down