Commit 3be2921

feat: moved find_free_port from test to src to use it in dcp to huggingface conversion
1 parent 436bf95

9 files changed: 21 additions & 16 deletions

src/modalities/conversion/gpt2/conversion_model.py

Lines changed: 4 additions & 1 deletion
@@ -16,6 +16,7 @@
 from modalities.models.utils import ModelTypeEnum, get_model_from_config
 from modalities.running_env.cuda_env import MultiProcessingCudaEnv
 from modalities.running_env.env_utils import PyTorchDtypes
+from modalities.utils.ports import find_free_port


 def convert_model_checkpoint(modalities_config: ConfigDictType) -> tuple[GPT2ForCausalLM | LlamaForCausalLM, GPT2LLM]:
@@ -110,7 +111,9 @@ def check_converted_dcp_model(
     vocab_size: int = new_config["model_raw" if "model_raw" in new_config else "model"]["config"]["vocab_size"]
     if isinstance(device_id_modalities, str):
         device_id_modalities = int(device_id_modalities.replace("cuda:", ""))
-    with MultiProcessingCudaEnv(ProcessGroupBackendType.nccl, 0, 0, 1, 24570, device_id=device_id_modalities):
+    with MultiProcessingCudaEnv(
+        ProcessGroupBackendType.nccl, 0, 0, 1, find_free_port(), device_id=device_id_modalities
+    ):
         modalities_model = get_model_from_config(new_config, model_type=ModelTypeEnum.DCP_CHECKPOINTED_MODEL)
         check_converted_model(hf_model, modalities_model, num_testruns=num_testruns, vocab_size=vocab_size)
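The substantive change in this hunk is the rendezvous port: the hardcoded 24570 is replaced by a port chosen at call time, so the conversion no longer fails when another process on the same host already holds that port. Purely as an illustration (this is not the commit's code path, which goes through MultiProcessingCudaEnv with nccl), the same single-rank setup can be sketched with raw torch.distributed on the gloo backend, so it runs without a GPU:

import os

import torch.distributed as dist

from modalities.utils.ports import find_free_port

# Single-rank process group on a dynamically chosen rendezvous port.
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(find_free_port())
dist.init_process_group(backend="gloo", rank=0, world_size=1)
try:
    ...  # load the DCP-checkpointed model and run the equivalence checks
finally:
    dist.destroy_process_group()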

src/modalities/utils/ports.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+import socket
+
+
+def find_free_port():
+    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    s.bind(("127.0.0.1", 0))
+    port = s.getsockname()[1]
+    s.close()
+    return port
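For readers unfamiliar with the idiom: binding to port 0 asks the OS for any free ephemeral port, getsockname() reads back the number that was assigned, and closing the socket releases it so the caller can bind it. The port is only guaranteed free at bind time, so another process can in principle claim it before it is reused; the removed tests/utility.py copy had the same small race. A minimal, purely illustrative sanity check of the helper:

import socket

from modalities.utils.ports import find_free_port

port = find_free_port()
# The returned port should be bindable again, because the helper closed its
# socket before returning. Another process could grab it in the gap, but for
# single-host test runs that is unlikely.
probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
probe.bind(("127.0.0.1", port))
probe.close()
print(f"port {port} is free for the rendezvous")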

tests/conversion/gpt2/conftest.py

Lines changed: 2 additions & 1 deletion
@@ -23,8 +23,9 @@
 from modalities.registry.registry import Registry
 from modalities.running_env.cuda_env import MultiProcessingCudaEnv
 from modalities.training.training_progress import TrainingProgress
+from modalities.utils.ports import find_free_port
 from tests.conftest import _ROOT_DIR
-from tests.utility import find_free_port, monitor_child_processes
+from tests.utility import monitor_child_processes


 @pytest.fixture

tests/dataloader/distributed/test_distributed_multidim_dataloader.py

Lines changed: 2 additions & 1 deletion
@@ -10,10 +10,11 @@
 from modalities.dataloader.sampler_factory import SamplerFactory
 from modalities.running_env.cuda_env import MultiProcessingCudaEnv
 from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_device_mesh, get_mesh_for_parallelism_method
+from modalities.utils.ports import find_free_port
 from tests.dataloader.distributed.mocks import MultiProcessingCudaEnvMock
 from tests.dataloader.dummy_sequential_dataset import TestDataset
 from tests.mocks import MockDeviceMesh
-from tests.utility import find_free_port, tensors_equal_across_mesh, tensors_pairwise_not_equal_across_mesh
+from tests.utility import tensors_equal_across_mesh, tensors_pairwise_not_equal_across_mesh


 @pytest.mark.parametrize("world_size, dp_degree", [(4, 2)])

tests/fsdp2_parallelization/test_tensor_parallelism.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 from modalities.models.gpt2.gpt2_model import TransformerMLP
 from modalities.models.model import SwiGLU
 from modalities.running_env.cuda_env import MultiProcessingCudaEnv
-from tests.utility import find_free_port
+from modalities.utils.ports import find_free_port


 def patch_config_file(original_config_path: Path, activation_type: str, tmp_dir: Path) -> Path:

tests/test_optimizer_factory.py

Lines changed: 1 addition & 1 deletion
@@ -19,8 +19,8 @@
 from modalities.registry.registry import Registry
 from modalities.running_env.cuda_env import MultiProcessingCudaEnv
 from modalities.running_env.env_utils import MixedPrecisionSettings
+from modalities.utils.ports import find_free_port
 from tests.conftest import _ROOT_DIR
-from tests.utility import find_free_port

 # number of parameters for each optimizer group
 GPT2_LINEAR = 66130944

tests/test_util.py

Lines changed: 1 addition & 1 deletion
@@ -13,8 +13,8 @@
 from modalities.config.pydantic_if_types import PydanticAppStateType, PydanticDeviceMeshIFType
 from modalities.running_env.cuda_env import MultiProcessingCudaEnv
 from modalities.util import get_local_number_of_trainable_parameters, get_total_number_of_trainable_parameters
+from modalities.utils.ports import find_free_port
 from modalities.utils.typing_utils import FSDPX
-from tests.utility import find_free_port


 def test_get_local_number_of_trainable_parameters():

tests/training/gradient_clipping/test_fsdp_gradient_clipper.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
     FSDP2LoggingOnlyGradientClipper,
     GradientClippingMode,
 )
-from tests.utility import find_free_port
+from modalities.utils.ports import find_free_port


 class MockFSDPModel:

tests/utility.py

Lines changed: 0 additions & 9 deletions
@@ -1,5 +1,4 @@
 import os
-import socket
 import time
 from multiprocessing import Queue
 from multiprocessing.managers import SyncManager
@@ -13,14 +12,6 @@
 from modalities.batch import DatasetBatch


-def find_free_port():
-    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    s.bind(("127.0.0.1", 0))
-    port = s.getsockname()[1]
-    s.close()
-    return port
-
-
 def add_debugger_to_distributed_test():
     """Add a debugger to a distributed test.
     This function should be called at the beginning of the test.
