Training

Efficient training tricks like Sequence Packing
from dart_math.train import *

Accelerating Training Several Times with Sequence Packing in 4 Lines of Code

Our interfaces can be integrated with HuggingFace datasets in 4 lines of code:

from dart_math.train import monkey_patch4pack, make_supervised_dset
# ...
monkey_patch4pack(model)
pack_dset = make_supervised_dset(tokenizer=tokenizer, data_path=data_args.data_path, pack_len=training_args.model_max_length, query_field=data_args.query_field, resp_field=data_args.resp_field, prompt_template=data_args.prompt_template)
trainer = Trainer(model=model, tokenizer=tokenizer, train_dataset=pack_dset)

monkey_patch4pack monkey-patches the model’s _get_unpad_data method so that attention is computed within each packed data sequence only (see the sketch below).
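For intuition, here is a minimal sketch (not the exact dart_math implementation; the function name is hypothetical) of what such a patched _get_unpad_data computes, assuming the convention from the MeetKai functionary code this library builds on: in a packed computation sequence, attention_mask stores the 1-based index of the data sequence each token belongs to, and 0 for padding.

import torch
import torch.nn.functional as F

def _get_unpad_data_packed(attention_mask: torch.Tensor):
    # Lengths of all data sequences in all packed computation sequences.
    seqlens = []
    for row in attention_mask:
        for seq_id in range(1, int(row.max()) + 1):
            seqlens.append(int((row == seq_id).sum()))
    seqlens_in_batch = torch.tensor(seqlens, dtype=torch.int32)
    # Positions of non-padding tokens, as expected by FlashAttention's varlen kernel.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    # Cumulative boundaries that keep attention from crossing data sequences.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

# [1, 1, 1, 2, 2, 0] (two data sequences + 1 padding token)
# -> cu_seqlens = [0, 3, 5], max_seqlen_in_batch = 3.
print(_get_unpad_data_packed(torch.tensor([[1, 1, 1, 2, 2, 0]])))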

make_supervised_dset will

  1. load, tokenize, and cache the dataset;
  2. pack the data points into computation sequences.

For a more detailed usage example, please refer to our training script for DART-Math.

Besides, for general dataset objects of the form [{"input_ids": [...], "labels": [...], "attention_mask": [...]}, ...], you can wrap them with PackedDataset to apply sequence packing:

from dart_math.train import PackedDataset
# ...
dset = PackedDataset(dataset=dset, tokenizer=tokenizer, pack_len=4096)
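For instance, with a toy HuggingFace datasets object, a hedged sketch looks like the following (the token IDs are made up and the model ID is a placeholder, purely for illustration):

from datasets import Dataset
from transformers import AutoTokenizer
from dart_math.train import PackedDataset

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # placeholder model ID

# Toy tokenized dataset with the three required fields (IDs are made up).
dset = Dataset.from_dict(
    {
        "input_ids": [[1, 306, 4966, 2], [1, 910, 338, 263, 11203, 2]],
        "labels": [[-100, -100, 4966, 2], [-100, -100, 338, 263, 11203, 2]],
        "attention_mask": [[1, 1, 1, 1], [1, 1, 1, 1, 1, 1]],
    }
)

pack_dset = PackedDataset(dataset=dset, tokenizer=tokenizer, pack_len=4096)
pack_dset.stat()  # the 2 data sequences should end up in 1 computation sequence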

source

monkey_patch4pack

 monkey_patch4pack (name_or_cls_or_obj:str|type|transformers.configuration_utils.PretrainedConfig)

Monkey patch the modeling module for packing. Must be called before instantiating the model.

name_or_cls_or_obj : str | type | transformers.configuration_utils.PretrainedConfig
    Name containing the model name, like “llama” / “mistral” / …
Returns : None

source

make_supervised_dset

 make_supervised_dset (tokenizer:transformers.tokenization_utils.PreTrainedTokenizer,
                       data_path:str|list[str],
                       query_field:str|list[str]='query',
                       resp_field:str|list[str]='response',
                       tokenized_cache_home:str='/home/runner/work/dart-math/dart-math/data/cache-tokenized',
                       shuffle_seed:int=42, pack_len:int=None,
                       prompt_template:str|dict[str,str]|dart_math.utils.PromptTemplate='alpaca')

Make dataset for supervised fine-tuning.

tokenizer : PreTrainedTokenizer
    (HF) tokenizer.
data_path : str | list[str]
    Dataset ID or path.
query_field : str | list[str], default "query"
    Field name for query.
resp_field : str | list[str], default "response"
    Field name for response.
tokenized_cache_home : str, default "/home/runner/work/dart-math/dart-math/data/cache-tokenized"
    Path to the tokenized cache home. Useful when repeatedly training on large datasets. None or "" means no cache.
shuffle_seed : int, default 42
    Seed for shuffling the dataset before packing. None or a negative value means no shuffling.
pack_len : int, default None
    Maximum length of a packed computation sequence in tokens. None or a non-positive value means no packing.
prompt_template : str | dict[str, str] | dart_math.utils.PromptTemplate, default "alpaca"
    ID / file path / PromptTemplate object of the prompt template.
Returns : dart_math.train.TokenizedSupervisedDataset | dart_math.train.PackedDataset
    Dataset ready for input to Trainer, containing at least the fields "input_ids", "labels", and "attention_mask".

Sequence Packing

Sequence Packing Accelerates Training 6-8x Compared with Simple Batching

Simple batching, which pads every data sequence to the maximum training length, wastes a lot of computation and memory on padding tokens, especially when data sequences are short and the maximum training length is long.

For example, if the model maximum training length is 4096 tokens (as in most base models like Mistral-7B, and as needed by the longest data sequences in some datasets like MATH), while data sequences are ~512 tokens long on average (as in most math SFT datasets), almost 1 - 1/8 = 7/8 of the computation and memory is wasted on padding tokens.

Sequence packing can eliminate this waste almost completely without affecting the training dynamics (for most models nowadays), except that the number of data sequences per batch changes.

In the example above, sequence packing accelerates training by about 6-8x.
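The back-of-the-envelope arithmetic behind this estimate can be checked in a few lines (the numbers are just the illustrative averages from the text):

# Rough waste / speedup estimate for simple batching vs. packing (illustrative numbers).
max_len = 4096   # model maximum training length
avg_len = 512    # average data sequence length
wasted = 1 - avg_len / max_len      # fraction of tokens that are padding
ideal_speedup = max_len / avg_len   # if the padding cost were removed entirely
print(f"wasted: {wasted:.1%}, ideal speedup: {ideal_speedup:.0f}x")
# -> wasted: 87.5%, ideal speedup: 8x; real-world overheads bring this down to ~6-8x.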

Basic Idea of Sequence Packing

The basic idea of sequence packing is

  • to merge/pack short data sequences into a single computation sequence, as long as the maximum training length, eliminating most of the waste on padding tokens (see the packing sketch after this list),
  • while trying its best not to affect the training dynamics by
    • manipulating attention masks to avoid cross-contamination between different data sequences,
    • working with relative positional encoding to avoid positional-information mismatch for the non-first data sequences in a packed computation sequence.
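A minimal sketch of the packing step, greedily merging data sequences by length (not necessarily the exact algorithm dart_math uses):

def greedy_pack(seq_lens: list[int], pack_len: int) -> list[list[int]]:
    """Greedily group data-sequence indices so each group fits within `pack_len` tokens."""
    packs: list[list[int]] = []
    cur: list[int] = []
    cur_len = 0
    for i, n in enumerate(seq_lens):
        if cur and cur_len + n > pack_len:  # current pack is full, start a new one
            packs.append(cur)
            cur, cur_len = [], 0
        cur.append(i)
        cur_len += n
    if cur:
        packs.append(cur)
    return packs

# 5 data sequences of lengths [500, 600, 3000, 700, 4000] with pack_len=4096
# become 3 computation sequences: [[0, 1], [2, 3], [4]].
print(greedy_pack([500, 600, 3000, 700, 4000], 4096))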

Manipulating Attention Masks to Avoid Cross-Contamination

Concretely, when we pack inputs, attention should stay within each individual data sequence. For example, assume that we pack 2 inputs: packed input = [input 1] [input 2]. Tokens from input 1 only attend to tokens from input 1, and tokens from input 2 only attend to tokens from input 2.

Example of packing 2 input sequences, “good morning my name is John” and “This is a dog”: the first attention matrix shows packing with cross-contamination; the second is the correct attention matrix for packing.

cf. https://github.com/MeetKai/functionary/tree/main/functionary/train/packing
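In code, the correct packed attention pattern is block-diagonal (and causal within each block). Here is a toy reconstruction of the figure above, with made-up token counts of 6 and 4:

import torch

# "good morning my name is John" (6 tokens) + "This is a dog" (4 tokens)
seq_lens = [6, 4]
total = sum(seq_lens)

# True = attention allowed; start from an all-False mask and fill per-sequence blocks.
mask = torch.zeros(total, total, dtype=torch.bool)
start = 0
for n in seq_lens:
    # Causal attention restricted to this data sequence only.
    mask[start:start + n, start:start + n] = torch.ones(n, n, dtype=torch.bool).tril()
    start += n

print(mask.int())
# The off-diagonal blocks stay 0: tokens of "This is a dog" never attend to tokens of
# "good morning my name is John", and vice versa. This is exactly what the patched
# `_get_unpad_data` achieves via `cu_seqlens` for FlashAttention.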

Relative Positional Encoding Works Perfectly with Sequence Packing

At first glance, sequence packing introduces another problem: the positional encodings of the non-first data sequences in a computation sequence differ from those in the vanilla non-packing setting.

This is indeed a problem for absolute positional encoding, but it practically does not matter for relative positional encodings like RoPE, which are the de facto standard nowadays.
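One way to see this for RoPE (in the notation of the RoPE paper): the attention score between a query at position m and a key at position n depends only on the token contents and the offset m - n, so shifting all positions of a data sequence by a constant c, which is effectively what packing does to non-first data sequences, leaves its within-sequence scores unchanged:

$$
\langle f_q(x_m, m),\, f_k(x_n, n) \rangle = g(x_m, x_n, m - n) = g\bigl(x_m, x_n, (m + c) - (n + c)\bigr).
$$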

API Reference


source

PackedDataset

 PackedDataset (dataset:torch.utils.data.dataset.Dataset|datasets.arrow_dataset.Dataset,
                tokenizer:transformers.tokenization_utils.PreTrainedTokenizer,
                pack_len:int, shuffle_seed:int=42)

Packed dataset containing computation sequences.

dataset : torch.utils.data.dataset.Dataset | datasets.arrow_dataset.Dataset
    Original tokenized dataset, which should contain at least the fields "input_ids", "labels", and "attention_mask".
tokenizer : PreTrainedTokenizer
    (HF) tokenizer.
pack_len : int
    Maximum length of a packed computation sequence in tokens.
shuffle_seed : int, default 42
    Seed for shuffling the dataset before packing. None / negative values mean no shuffling.

source

PackedDataset.stat

 PackedDataset.stat ()

Print out the statistics of the packed dataset (original -> packed): 1. number of data/computation sequences; 2. average effective length of computation sequences.
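For example, continuing the toy PackedDataset snippet above:

pack_dset.stat()  # reports how many data sequences were merged and the average effective lengths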


source

TokenizedSupervisedDataset

 TokenizedSupervisedDataset (tokenizer:transformers.tokenization_utils.PreTrainedTokenizer,
                             input_ids:list[torch.Tensor]=None,
                             labels:list[torch.Tensor]=None,
                             attention_mask:list[torch.Tensor]=None)

Tokenized dataset for supervised fine-tuning.

tokenizer : PreTrainedTokenizer
    (HF) tokenizer. None for an empty dataset.
input_ids : list[torch.Tensor], default None
    List of input token ID sequences.
labels : list[torch.Tensor], default None
    List of label sequences.
attention_mask : list[torch.Tensor], default None
    List of attention mask sequences.

source

TokenizedSupervisedDataset.load_from_raw_dset

 TokenizedSupervisedDataset.load_from_raw_dset (tokenizer:transformers.tokenization_utils.PreTrainedTokenizer,
                                                data_path:str, query_field:str='query',
                                                resp_field:str='response',
                                                prompt_template:str|dict[str,str]|dart_math.utils.PromptTemplate='alpaca')

Load a dataset from a file and tokenize it.

tokenizer : PreTrainedTokenizer
    (HF) tokenizer.
data_path : str
    Dataset ID or path.
query_field : str, default "query"
    Field name for query.
resp_field : str, default "response"
    Field name for response.
prompt_template : str | dict[str, str] | dart_math.utils.PromptTemplate, default "alpaca"
    ID / file path / PromptTemplate object of the prompt template.
Returns : TokenizedSupervisedDataset
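A hedged usage sketch (the model ID and data path below are placeholders, not values from this project):

from transformers import AutoTokenizer
from dart_math.train import TokenizedSupervisedDataset

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # placeholder model ID
tok_dset = TokenizedSupervisedDataset.load_from_raw_dset(
    tokenizer=tokenizer,
    data_path="path/to/data.json",  # placeholder dataset ID or path
    query_field="query",
    resp_field="response",
    prompt_template="alpaca",
)
print(tok_dset[0].keys())  # "input_ids", "labels", "attention_mask"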

source

TokenizedSupervisedDataset.__getitem__

 TokenizedSupervisedDataset.__getitem__ (i:int)

Get a data point.

i : int
    dataset[i]
Returns : dict
    {"input_ids": input_ids[i], "labels": labels[i], "attention_mask": attention_mask[i]}

source

TokenizedSupervisedDataset.concat

 TokenizedSupervisedDataset.concat (datasets:list['TokenizedSupervisedDataset'])

Concatenate TokenizedSupervisedDataset instances to the current dataset.

datasets : list[TokenizedSupervisedDataset]
    List of tokenized datasets to concatenate. Each dataset should contain at least the fields "input_ids", "labels", and "attention_mask".
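A hedged sketch of combining two tokenized datasets before packing, reusing the placeholder tokenizer from the sketch above (the data paths are placeholders):

dset_a = TokenizedSupervisedDataset.load_from_raw_dset(tokenizer=tokenizer, data_path="path/to/a.json")
dset_b = TokenizedSupervisedDataset.load_from_raw_dset(tokenizer=tokenizer, data_path="path/to/b.json")
dset_a.concat([dset_b])   # dset_a now also contains dset_b's data points
dset_a.shuffle(seed=42)   # optionally shuffle (default seed per the docs) before packing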


source

TokenizedSupervisedDataset.shuffle

 TokenizedSupervisedDataset.shuffle (seed:int=42)

Shuffle the dataset.


source

TokenizedSupervisedDataset.pad

 TokenizedSupervisedDataset.pad ()

Pad all data points in the dataset to the length of the longest data point.

Acknowledgements

Thanks to https://github.com/MeetKai/functionary/tree/main/functionary/train/packing, on which our sequence packing code is largely based.
