Quick Start

Use dart-math in just a few lines of code.

Installation

We recommend using Conda and pip to manage your environment. Run the following commands to set up your environment:

git clone https://github.com/hkust-nlp/dart-math.git && cd dart-math
conda create --name dart-math --yes python=3.11
conda activate dart-math
pip install -r requirements.txt
pip install flash-attn --no-build-isolation

For common users/developers, just run the following command to install the dart-math package:

pip install -e "."

For prospective contributors, we recommend installing the package with the dev extras:

pip install -e ".[dev]"
pip install pre-commit

dart_math.train: Efficient Training Tricks

Accelerating Several Times with Sequence Packing in 4 Lines of Code

Our interfaces can be integrated with the HuggingFace datasets in 4 lines of code:

from transformers import Trainer
from dart_math.train import monkey_patch4pack, make_supervised_dset
# ...
monkey_patch4pack(model)  # patch the model so it can train on packed sequences
pack_dset = make_supervised_dset(
    tokenizer=tokenizer, data_path=data_args.data_path, pack_len=training_args.model_max_length,
    query_field=data_args.query_field, resp_field=data_args.resp_field, prompt_template=data_args.prompt_template,
)
trainer = Trainer(model=model, tokenizer=tokenizer, train_dataset=pack_dset)

monkey_patch4pack monkey-patches the model’s _get_unpad_data method.
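For intuition, here is a minimal, stdlib-only sketch of what a packing-aware `_get_unpad_data` typically computes. It assumes (an assumption, not something this page states) that each packed data point is marked in `attention_mask` by its own segment id (1, 2, 3, ..., with 0 for padding). From that mask it derives the flattened non-padding positions, the cumulative sequence lengths (`cu_seqlens`) that FlashAttention's variable-length kernels use to keep attention inside each data point, and the maximum sequence length. The function name `unpad_data_from_segments` is hypothetical, and the real patched method works on torch tensors rather than Python lists.

```python
from typing import Dict, List, Tuple


def unpad_data_from_segments(attention_mask: List[List[int]]) -> Tuple[List[int], List[int], int]:
    """Hypothetical helper: derive FlashAttention-style unpadding data from a
    segment-id attention mask (0 = padding, k > 0 = k-th packed data point).

    Assumes every row has the same width and that segments are contiguous and
    appear in increasing id order. Returns (indices, cu_seqlens, max_seqlen).
    """
    indices: List[int] = []
    seqlens: List[int] = []
    width = len(attention_mask[0]) if attention_mask else 0
    for row_idx, row in enumerate(attention_mask):
        counts: Dict[int, int] = {}  # segment id -> number of tokens in this row
        for col_idx, seg_id in enumerate(row):
            if seg_id > 0:
                indices.append(row_idx * width + col_idx)  # flattened position
                counts[seg_id] = counts.get(seg_id, 0) + 1
        # Each segment counts as its own sequence, not the whole row.
        seqlens.extend(counts[seg] for seg in sorted(counts))
    cu_seqlens = [0]
    for length in seqlens:
        cu_seqlens.append(cu_seqlens[-1] + length)
    return indices, cu_seqlens, max(seqlens, default=0)


# Two packed rows: row 0 holds 2 data points plus 1 pad, row 1 holds 3 data points.
indices, cu_seqlens, max_seqlen = unpad_data_from_segments([[1, 1, 1, 2, 2, 0], [1, 1, 2, 2, 2, 3]])
print(cu_seqlens, max_seqlen)  # [0, 3, 5, 7, 10, 11] 3
```

Treating every segment as a separate sequence is what prevents tokens of one packed data point from attending to another.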

make_supervised_dset:

  1. loads, tokenizes and caches the dataset;
  2. packs the data points into computation sequences.
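The packing step can be illustrated with a simple order-preserving greedy strategy: walk through the tokenized data points and start a new pack whenever the next one would overflow `pack_len`. This is only a sketch under that assumption; the page does not say which packing algorithm make_supervised_dset actually uses, and the helper `greedy_pack` is hypothetical.

```python
from typing import List


def greedy_pack(lengths: List[int], pack_len: int) -> List[List[int]]:
    """Hypothetical helper: group data-point indices into packs whose total
    token count does not exceed `pack_len`, preserving the original order.

    Data points longer than `pack_len` are skipped here (a simplifying
    assumption; a real pipeline might truncate them instead).
    """
    packs: List[List[int]] = []
    current: List[int] = []
    current_len = 0
    for idx, length in enumerate(lengths):
        if length > pack_len:
            continue  # cannot fit into any pack
        if current_len + length > pack_len:
            # Close the current pack and start a new one with this data point.
            packs.append(current)
            current, current_len = [], 0
        current.append(idx)
        current_len += length
    if current:
        packs.append(current)
    return packs


print(greedy_pack([1000, 3000, 500, 2500, 5000, 1200], pack_len=4096))
# [[0, 1], [2, 3], [5]]
```

Fewer, fuller sequences mean far less padding per batch, which is where the speed-up of sequence packing comes from.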

For a more detailed usage example, please refer to our training script for DART-Math.

Besides, for general dataset objects of the form [{"input_ids": [...], "labels": [...], "attention_mask": [...]}, ...], you can wrap them with PackedDataset to apply sequence packing:

from dart_math.train import PackedDataset
# ...
dset = PackedDataset(dataset=dset, tokenizer=tokenizer, pack_len=4096)
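To show what one packed example might look like, the sketch below concatenates a group of data points in exactly this dict format into a single sequence of `pack_len` tokens. It assumes (not confirmed by this page) that the per-point masks are replaced by segment ids (1, 2, ..., 0 for padding) and that padding labels are set to -100, the index PyTorch's cross-entropy loss ignores by default. The helper `merge_pack` and its `pad_token_id` parameter are hypothetical and are not part of PackedDataset's API.

```python
from typing import Dict, List


def merge_pack(data_points: List[Dict[str, List[int]]], pack_len: int, pad_token_id: int = 0) -> Dict[str, List[int]]:
    """Hypothetical helper: concatenate tokenized data points into one packed
    example of exactly `pack_len` tokens.

    The original per-point attention masks are assumed to be all ones and are
    replaced by segment ids so that a patched attention implementation can tell
    the packed data points apart.
    """
    input_ids: List[int] = []
    labels: List[int] = []
    attention_mask: List[int] = []
    for seg_id, dp in enumerate(data_points, start=1):
        input_ids.extend(dp["input_ids"])
        labels.extend(dp["labels"])
        attention_mask.extend([seg_id] * len(dp["input_ids"]))
    if len(input_ids) > pack_len:
        raise ValueError("data points do not fit into pack_len")
    n_pad = pack_len - len(input_ids)
    return {
        "input_ids": input_ids + [pad_token_id] * n_pad,
        "labels": labels + [-100] * n_pad,  # padding is ignored by the loss
        "attention_mask": attention_mask + [0] * n_pad,  # 0 marks padding
    }


packed = merge_pack(
    [
        {"input_ids": [5, 6], "labels": [-100, 6], "attention_mask": [1, 1]},
        {"input_ids": [7, 8, 9], "labels": [7, 8, 9], "attention_mask": [1, 1, 1]},
    ],
    pack_len=8,
)
print(packed["attention_mask"])  # [1, 1, 2, 2, 2, 0, 0, 0]
```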

For more details or more general interfaces, please refer to the documentation of dart_math.train.

dart_math.gen – Efficient Generation with Flexible Stopping Criteria

Difficulty-Aware Rejection Sampling (with Code Execution) in 5 Lines of Code

from dart_math.data import RespSampleVLLM, load_query_dps
from dart_math.eval import EvaluatorMathBatch
from dart_math.exec import CodeExecCfg
from dart_math.gen import Generator, is_dp_dars_finished
# ...
generator = Generator(
    llm, sampling_params, resp_sample_cls=RespSampleVLLM,
    batch_evaluator=(EvaluatorMathBatch() if not args.gen_only else None),
    code_exec_cfg=(CodeExecCfg.load_from_id_or_path(args.code_exec_cfg) if args.code_exec_cfg else None),
)
generator.gen(query_dps=query_dps, dp_stop_criteria=is_dp_dars_finished, save_path=args.gen_save_path, n_paths_per_save=args.save_gen_path_bs)

  1. generator.gen generates responses with the vLLM model llm, using the sampling parameters sampling_params, for the query data points query_dps until every data point meets the stopping criterion dp_stop_criteria.
  2. Samples are generated in batches and, if batch_evaluator is specified, evaluated with it.
  3. Generated samples are saved to save_path.
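The control flow of difficulty-aware rejection sampling can be sketched without vLLM: keep sampling every unfinished query in batches, keep only correct responses, and stop a query once it has enough correct responses or has used up its attempt budget, so harder queries naturally receive more samples. Everything here is hypothetical: `dars_loop`, `sample_fn`, `is_correct`, `n_correct_target` and `max_attempts` are illustrative stand-ins (the `finished` check plays the role of dp_stop_criteria), not the Generator API.

```python
from typing import Callable, Dict, List


def dars_loop(
    queries: List[str],
    sample_fn: Callable[[str], str],
    is_correct: Callable[[str, str], bool],
    n_correct_target: int,
    max_attempts: int,
    batch_size: int = 4,
) -> Dict[str, List[str]]:
    """Hypothetical sketch of difficulty-aware rejection sampling.

    Because sampling happens in whole batches, a query may slightly overshoot
    `n_correct_target` or `max_attempts` (by at most batch_size - 1).
    """
    correct: Dict[str, List[str]] = {q: [] for q in queries}
    attempts: Dict[str, int] = {q: 0 for q in queries}

    def finished(q: str) -> bool:
        # Per-query stopping criterion.
        return len(correct[q]) >= n_correct_target or attempts[q] >= max_attempts

    while not all(finished(q) for q in queries):
        for q in queries:
            if finished(q):
                continue
            # Sample a batch for this unfinished query; keep only correct responses.
            for _ in range(batch_size):
                resp = sample_fn(q)
                attempts[q] += 1
                if is_correct(q, resp):
                    correct[q].append(resp)
    return correct


# Toy usage with a deterministic stand-in "model": "easy" is always answered
# correctly, "hard" only on every 5th call.
calls = {"easy": 0, "hard": 0}


def toy_sample(q: str) -> str:
    calls[q] += 1
    return "right" if calls[q] % (1 if q == "easy" else 5) == 0 else "wrong"


res = dars_loop(["easy", "hard"], toy_sample, lambda q, r: r == "right", n_correct_target=4, max_attempts=100)
print(calls)  # {'easy': 4, 'hard': 20}
```

Both queries end with the same number of correct responses, but the hard one consumed five times as many samples, which is the intended bias towards difficult queries.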

For a more detailed usage example, please refer to our generation script for DART-Math.

dart_math.eval – Elaborate (Mathematical) Evaluation

EvaluatorMath implements an elaborate evaluation pipeline for mathematical reasoning tasks.

from dart_math.eval import EvaluatorMath

math_evaluator = EvaluatorMath()

For more details or more general interfaces, please refer to the documentation of dart_math.eval.

Accurately Extracting Answer Strings

EvaluatorMath can:

  1. extract short answers from long responses fairly accurately;
  2. normalize them into standard mathematical expressions.

>>> # MATH-style boxed answer
>>> math_evaluator.extract_ans("Therefore, $1+1=\\boxed{2}$.")
'2'
>>> # Answer around "answer"
>>> math_evaluator.extract_ans(
...     "Both $1$ and $11$ divide $11,$ so $\\boxed{11}=2$, and since $1,$ $2,$ $4,$ $5,$ $10,$ and $20$ divide $20,$ then $\\boxed{20}=6$. The inner expression, $\\boxed{11}\\times\\boxed{20}=2\\times6=12$. Finally, $\\boxed{12}=6$ because $1,$ $2,$ $3,$ $4,$ $6,$ and $12$ divide $12.$\n\nTherefore, $6$ is our answer. Please note that we have not boxed the correct answer as we normally do, as that would be especially confusing for this problem."
... )
'6'
>>> # Use the last number by default
>>> math_evaluator.extract_ans(
...     'First, we need to count the total number of letters in the word "CIRCLE". There are 6 letters.\n\nNext, we need to count the number of distinct letters. There are 6 distinct letters in the word "CIRCLE": C, I, R, L, E, and G.\n\nNow, let\'s consider the arrangements of the distinct letters. The number of ways to arrange n distinct items is n factorial (n!). So, we have 6! = 6 × 5 × 4 × 3 × 2 × 1 = 720 ways to arrange the distinct letters.\n\nHowever, the word "CIRCLE" has one letter that repeats (the letter \'C\' repeats twice). We have over-counted the number of distinct arrangements by including arrangements that are just rotations of each other (for example, "CIRCLE" and "LCIRCE" are considered different arrangements here, but they are the same word when read).\n\nTo correct for this, we divide the total number of arrangements by the number of ways to arrange the repeated letters. The number of ways to arrange 2 identical items is 2! = 2 × 1 = 2. So, we divide the total number of arrangements by 2 to get the correct number of distinct arrangements.\n\nTherefore, the number of ways to arrange the letters of the word "CIRCLE" is 720 ÷ 2 = 360.'
... )
'360'
>>> # More cases ...

>>> # Normalize fraction
>>> math_evaluator.extract_ans("The answer is 1/2")
'\\frac{1}{2}'
>>> # Normalize pmatrix
>>> math_evaluator.extract_ans(
...     "The answer is \\begin{pmatrix} 3 \\\\ \\frac{\\pi}{2} \\end{pmatrix}"
... )
'\\begin{array}3\\\\frac{\\pi}{2}\\end{array}'
>>> # More cases ...
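The core of the first and the fraction examples can be sketched in a few lines: take the content of the last `\boxed{...}` (matching nested braces), otherwise fall back to the last number, and rewrite a plain `a/b` as `\frac{a}{b}`. This is a much simplified, hypothetical `simple_extract_ans`, not EvaluatorMath's implementation: for instance, on the second example above it would return '12' (the last boxed value), whereas EvaluatorMath returns '6' by looking at the text around "answer".

```python
import re
from typing import Optional


def extract_last_boxed(text: str) -> Optional[str]:
    """Return the content of the last \\boxed{...} in `text`, matching nested braces."""
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    content_start = i = start + len("\\boxed{")
    depth = 1
    while i < len(text):
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
            if depth == 0:
                return text[content_start:i]
        i += 1
    return None  # unbalanced braces


def simple_extract_ans(text: str) -> Optional[str]:
    """Hypothetical, simplified extractor: prefer the last \\boxed{} answer,
    else fall back to the last number; rewrite a plain fraction a/b as \\frac{a}{b}."""
    ans = extract_last_boxed(text)
    if ans is None:
        numbers = re.findall(r"-?\d+(?:\.\d+)?(?:/\d+)?", text)
        if not numbers:
            return None
        ans = numbers[-1]
    frac = re.fullmatch(r"(-?\d+)/(\d+)", ans.strip())
    if frac:
        ans = f"\\frac{{{frac.group(1)}}}{{{frac.group(2)}}}"
    return ans


print(simple_extract_ans("Therefore, $1+1=\\boxed{2}$."))  # 2
print(simple_extract_ans("The answer is 1/2"))  # \frac{1}{2}
```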

Correctly Processing Various Mathematical Objects / Special Text

EvaluatorMath, based on regular expressions and SymPy symbolic calculation, can correctly process

  • most mathematical objects, such as matrices (vectors), intervals, and symbols besides numbers,
  • as well as some special texts, such as Boolean expressions, dates, and times.

>>> math_evaluator.eq("x+y", "y+x") == True  # Expression
True
>>> math_evaluator.eq("\\frac{1}{2}", "0.5") == True  # LaTeX
True
>>> math_evaluator.eq(
...     "\\begin{array}1\\\\2\\end{array}",
...     "1,2",
... )  # Matrix (Vector)
True
>>> math_evaluator.eq("{1,2}", "{2,1}", compare_sets=True)  # Set
True
>>> math_evaluator.eq("no", "false")  # Bool
True
>>> # More mathematical objects and special texts ...
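The "normalize, then compare" idea behind eq can be sketched with the standard library alone. This hypothetical `simple_eq` swaps SymPy for exact `fractions.Fraction` arithmetic, so it handles the LaTeX-fraction and set cases above but, unlike EvaluatorMath, cannot prove symbolic identities such as x+y = y+x; for anything it cannot parse as a number it only falls back to a whitespace-insensitive string match.

```python
import re
from fractions import Fraction
from typing import Optional


def to_fraction(s: str) -> Optional[Fraction]:
    """Parse '0.5', '1/2' or '\\frac{1}{2}' into an exact Fraction, else None."""
    s = s.strip()
    m = re.fullmatch(r"\\frac\{(-?\d+)\}\{(\d+)\}", s)
    if m:
        s = f"{m.group(1)}/{m.group(2)}"
    try:
        return Fraction(s)
    except (ValueError, ZeroDivisionError):
        return None


def simple_eq(a: str, b: str, compare_sets: bool = False) -> bool:
    """Hypothetical, stdlib-only equivalence check: optional order-insensitive
    set comparison, then exact numeric comparison, then normalized string match."""
    if compare_sets:
        items_a = {x.strip() for x in a.strip("{}").split(",")}
        items_b = {x.strip() for x in b.strip("{}").split(",")}
        return items_a == items_b  # order-insensitive, but items compared as strings
    fa, fb = to_fraction(a), to_fraction(b)
    if fa is not None and fb is not None:
        return fa == fb  # exact: 0.33 != 1/3
    return re.sub(r"\s+", "", a) == re.sub(r"\s+", "", b)


print(simple_eq("\\frac{1}{2}", "0.5"))  # True
print(simple_eq("{1,2}", "{2,1}", compare_sets=True))  # True
```

Comparing exact rationals rather than floats avoids tolerance choices, which is one reason symbolic or exact arithmetic is preferred for answer checking.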

Batch Evaluation with Timeout

SymPy symbolic calculation risks extremely long evaluation times.

To address this, we implement EvaluatorMathBatch, which evaluates in batches with a timeout while remaining efficient, using asyncio coroutines instead of the multiprocessing used in previous implementations.

from dart_math.eval import EvaluatorMathBatch
answers, corrects = EvaluatorMathBatch().batch_eval(resp_samples)
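The timeout mechanism itself can be sketched with asyncio alone: run each blocking check in a worker thread via `asyncio.to_thread`, bound it with `asyncio.wait_for`, and gather all results concurrently. This is a hypothetical `batch_eval_with_timeout`, not EvaluatorMathBatch's code. Two caveats apply to this sketch: a timed-out thread is not killed (so `asyncio.run` still waits for it while shutting down its default executor), and CPU-bound work in threads does not run in parallel under the GIL, so it illustrates timeout handling rather than throughput.

```python
import asyncio
import time
from typing import Callable, List, Optional


async def _eval_one(check: Callable[[str], bool], resp: str, timeout: float) -> Optional[bool]:
    """Run a blocking check in a worker thread; stop waiting after `timeout` seconds."""
    try:
        return await asyncio.wait_for(asyncio.to_thread(check, resp), timeout=timeout)
    except asyncio.TimeoutError:
        return None  # treat timed-out samples as "unknown"


async def _eval_all(check: Callable[[str], bool], resps: List[str], timeout: float) -> List[Optional[bool]]:
    # gather() returns results in the same order as `resps`.
    return list(await asyncio.gather(*(_eval_one(check, r, timeout) for r in resps)))


def batch_eval_with_timeout(check: Callable[[str], bool], resps: List[str], timeout: float = 1.0) -> List[Optional[bool]]:
    """Hypothetical sketch: evaluate all responses concurrently, each with its own timeout."""
    return asyncio.run(_eval_all(check, resps, timeout))


def toy_check(resp: str) -> bool:
    if resp == "slow":
        time.sleep(0.5)  # simulate a pathologically slow symbolic computation
    return resp == "ok"


print(batch_eval_with_timeout(toy_check, ["ok", "bad", "slow"], timeout=0.1))
# [True, False, None]
```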

More Details

Please browse the sidebar for more details on the different modules.
