2 changes: 2 additions & 0 deletions src/modalities/config/config.py
@@ -448,6 +448,8 @@ class PackedMemMapDatasetMegatronConfig(BaseModel):

class CombinedDatasetConfig(BaseModel):
datasets: list[PydanticDatasetIFType]
log_chunk_switch: bool = False
log_initial_pos: bool = False


class BatchSamplerConfig(BaseModel):
22 changes: 21 additions & 1 deletion src/modalities/dataloader/dataset.py
@@ -14,6 +14,7 @@
from modalities.dataloader.create_packed_data import EmbeddedStreamData
from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader
from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper
from modalities.utils.logger_utils import get_logger


class Dataset(TorchdataSet):
@@ -445,20 +446,39 @@ class CombinedDataset(Dataset):
In the Dataloader, a batch will still contain packed samples from different datasets.
"""

def __init__(self, datasets: list[Dataset]):
def __init__(self, datasets: list[Dataset], log_chunk_switch: bool = False, log_initial_pos: bool = False):
"""Initializes the CombinedDataset object, combining multiple datasets.

Args:
datasets (list[Dataset]): A list of datasets to combine.
log_chunk_switch (bool, optional): Whether to log when switching between dataset chunks. Defaults to False.
log_initial_pos (bool, optional): Whether to log the initial position at the beginning of a training
run or warmstart. Defaults to False.
"""
self.log_chunk_switch = log_chunk_switch
self.log_initial_pos = log_initial_pos
self.already_logged_initial_pos = False
self.datasets = datasets
self.cumulative_sizes = np.cumsum([len(ds) for ds in datasets], dtype=np.int64)
self.logger = get_logger(__name__)

def __len__(self) -> int:
return self.cumulative_sizes[-1]

def __getitem__(self, idx: int) -> dict:
dataset_idx = np.searchsorted(self.cumulative_sizes, idx, side="right")
local_idx = idx - (self.cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0)
if self.log_chunk_switch and local_idx == 0:
self.logger.info(
f"Chunk switch: global_index={idx}/{len(self)} chunk index={dataset_idx}/{len(self.datasets)}, "
f"local index={local_idx}/{len(self.datasets[dataset_idx])}"
)

if self.log_initial_pos and not self.already_logged_initial_pos:
self.logger.info(
f"Initial pos: global_index={idx}/{len(self)} chunk index={dataset_idx}/{len(self.datasets)}, "
f"local index={local_idx}/{len(self.datasets[dataset_idx])}"
)
self.already_logged_initial_pos = True

return self.datasets[dataset_idx][local_idx]
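
For reference, a minimal standalone sketch (numpy only, with hypothetical chunk sizes not taken from this PR) of the global-to-local index mapping that `__getitem__` performs, and of when the chunk-switch log fires (whenever the local index is 0):

```python
# Standalone sketch of CombinedDataset's global-to-local index mapping.
# The chunk sizes below are hypothetical; only the arithmetic mirrors the PR.
import numpy as np

sizes = [3, 5, 4]                                    # lengths of the combined datasets
cumulative_sizes = np.cumsum(sizes, dtype=np.int64)  # -> [3, 8, 12]


def locate(idx: int) -> tuple[int, int]:
    """Map a global sample index to (dataset_idx, local_idx)."""
    dataset_idx = int(np.searchsorted(cumulative_sizes, idx, side="right"))
    local_idx = idx - (int(cumulative_sizes[dataset_idx - 1]) if dataset_idx > 0 else 0)
    return dataset_idx, local_idx


assert locate(0) == (0, 0)    # first sample overall -> chunk-switch (and initial-pos) log
assert locate(2) == (0, 2)    # last sample of the first chunk
assert locate(3) == (1, 0)    # first sample of the second chunk -> chunk-switch log
assert locate(11) == (2, 3)   # last sample overall (len = 12)
```

Because `side="right"` is used, a global index equal to a cumulative boundary maps to local index 0 of the next chunk, which is exactly the condition the chunk-switch log checks.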
9 changes: 7 additions & 2 deletions src/modalities/dataloader/dataset_factory.py
@@ -116,13 +116,18 @@ def get_packed_mem_map_dataset_megatron(
return dataset

@staticmethod
def get_combined_dataset(datasets: list[Dataset]) -> Dataset:
def get_combined_dataset(
datasets: list[Dataset], log_chunk_switch: bool = False, log_initial_pos: bool = False
) -> Dataset:
"""Factory method for creating a combined datset .

Args:
datasets (list[Dataset]): List of datasets to combine.
log_chunk_switch (bool, optional): Whether to log when switching between dataset chunks. Defaults to False.
log_initial_pos (bool, optional): Whether to log the initial position at the beginning of a
training run or warmstart. Defaults to False.

Returns:
Dataset: CombinedDataset object.
"""
return CombinedDataset(datasets=datasets)
return CombinedDataset(datasets=datasets, log_chunk_switch=log_chunk_switch, log_initial_pos=log_initial_pos)
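
For context, a hedged end-to-end sketch of the new flags in use; it constructs `CombinedDataset` directly with two list-backed stand-in datasets (not part of this PR), and whether the INFO messages actually surface depends on how `get_logger` configures handlers in the project:

```python
# Sketch only: the two list-backed "datasets" are stand-ins, not modalities datasets.
import logging

from modalities.dataloader.dataset import CombinedDataset

logging.basicConfig(level=logging.INFO)

chunk_a = [{"input_ids": [1, 2]}, {"input_ids": [3, 4]}]  # pretend chunk with 2 samples
chunk_b = [{"input_ids": [5, 6]}]                         # pretend chunk with 1 sample

combined = CombinedDataset(
    datasets=[chunk_a, chunk_b],
    log_chunk_switch=True,  # log every time __getitem__ enters a chunk at local index 0
    log_initial_pos=True,   # log the first position fetched after a (warm)start, once
)

_ = combined[0]  # logs the initial position and a chunk switch into chunk 0
_ = combined[2]  # logs a chunk switch into chunk 1
```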