diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py
index 13a37103b..060679582 100644
--- a/src/modalities/config/config.py
+++ b/src/modalities/config/config.py
@@ -448,6 +448,8 @@ class PackedMemMapDatasetMegatronConfig(BaseModel):
 
 class CombinedDatasetConfig(BaseModel):
     datasets: list[PydanticDatasetIFType]
+    log_chunk_switch: bool = False
+    log_initial_pos: bool = False
 
 
 class BatchSamplerConfig(BaseModel):
diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py
index 0ef4d9076..e4e26f02b 100644
--- a/src/modalities/dataloader/dataset.py
+++ b/src/modalities/dataloader/dataset.py
@@ -14,6 +14,7 @@
 from modalities.dataloader.create_packed_data import EmbeddedStreamData
 from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader
 from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper
+from modalities.utils.logger_utils import get_logger
 
 
 class Dataset(TorchdataSet):
@@ -445,14 +446,21 @@ class CombinedDataset(Dataset):
     In the Dataloader, a batch will still contain packed samples from different datasets.
     """
 
-    def __init__(self, datasets: list[Dataset]):
+    def __init__(self, datasets: list[Dataset], log_chunk_switch: bool = False, log_initial_pos: bool = False):
         """Initializes the CombinedDataset object, combining multiple datasets.
 
         Args:
             datasets (list[Dataset]): A list of datasets to combine.
+            log_chunk_switch (bool, optional): Whether to log when switching between dataset chunks. Defaults to False.
+            log_initial_pos (bool, optional): Whether to log the initial position at the beginning of a training
+                or warmstart. Defaults to False.
         """
+        self.log_chunk_switch = log_chunk_switch
+        self.log_initial_pos = log_initial_pos
+        self.already_logged_initial_pos = False
         self.datasets = datasets
         self.cumulative_sizes = np.cumsum([len(ds) for ds in datasets], dtype=np.int64)
+        self.logger = get_logger(__name__)
 
     def __len__(self) -> int:
         return self.cumulative_sizes[-1]
@@ -460,5 +468,17 @@ def __len__(self) -> int:
     def __getitem__(self, idx: int) -> dict:
         dataset_idx = np.searchsorted(self.cumulative_sizes, idx, side="right")
         local_idx = idx - (self.cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0)
+        if self.log_chunk_switch and local_idx == 0:
+            self.logger.info(
+                f"Chunk switch: global_index={idx}/{len(self)} chunk index={dataset_idx}/{len(self.datasets)}, "
+                f"local index={local_idx}/{len(self.datasets[dataset_idx])}"
+            )
+
+        if self.log_initial_pos and not self.already_logged_initial_pos:
+            self.logger.info(
+                f"Initial pos: global_index={idx}/{len(self)} chunk index={dataset_idx}/{len(self.datasets)}, "
+                f"local index={local_idx}/{len(self.datasets[dataset_idx])}"
+            )
+            self.already_logged_initial_pos = True
         return self.datasets[dataset_idx][local_idx]
diff --git a/src/modalities/dataloader/dataset_factory.py b/src/modalities/dataloader/dataset_factory.py
index 1eab9e328..8539c5d3a 100644
--- a/src/modalities/dataloader/dataset_factory.py
+++ b/src/modalities/dataloader/dataset_factory.py
@@ -116,13 +116,18 @@ def get_packed_mem_map_dataset_megatron(
         return dataset
 
     @staticmethod
-    def get_combined_dataset(datasets: list[Dataset]) -> Dataset:
+    def get_combined_dataset(
+        datasets: list[Dataset], log_chunk_switch: bool = False, log_initial_pos: bool = False
+    ) -> Dataset:
         """Factory method for creating a combined dataset.
 
         Args:
             datasets (list[Dataset]): List of datasets to combine.
+            log_chunk_switch (bool, optional): Whether to log when switching between dataset chunks. Defaults to False.
+            log_initial_pos (bool, optional): Whether to log the initial position at the beginning of a
+                training or warmstart. Defaults to False.
 
         Returns:
             Dataset: CombinedDataset object.
         """
-        return CombinedDataset(datasets=datasets)
+        return CombinedDataset(datasets=datasets, log_chunk_switch=log_chunk_switch, log_initial_pos=log_initial_pos)