from dart_math.eval import EvaluatorMath
math_evaluator = EvaluatorMath()
Quick Start
Get started with dart-math in several lines of code.
Installation
We recommend using Conda and pip to manage your environment. Run the following commands to set up your environment:
git clone https://github.com/hkust-nlp/dart-math.git && cd dart-math
conda create --name dart-math --yes python=3.11
conda activate dart-math
pip install -r requirements.txt
pip install flash-attn --no-build-isolation
Most users and developers only need to run the following command to install the dart-math package:
pip install -e "."
Prospective contributors should instead install the package with the dev extras, along with pre-commit:
pip install -e ".[dev]"
pip install pre-commit
dart_math.train: Efficient Training Tricks
Accelerating Several Times with Sequence Packing in 4 Lines of Code
Our interfaces can be integrated with HuggingFace datasets in 4 lines of code:
from transformers import Trainer
from dart_math.train import monkey_patch4pack, make_supervised_dset
# ...
monkey_patch4pack(model)
pack_dset = make_supervised_dset(tokenizer=tokenizer, data_path=data_args.data_path, pack_len=training_args.model_max_length, query_field=data_args.query_field, resp_field=data_args.resp_field, prompt_template=data_args.prompt_template)
trainer = Trainer(model=model, tokenizer=tokenizer, train_dataset=pack_dset)
monkey_patch4pack monkey-patches the model's _get_unpad_data method.
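For context on why this patch matters: in HuggingFace transformers' FlashAttention-2 code path, _get_unpad_data turns the attention mask into the flat indices of real tokens, the cumulative sequence lengths (cu_seqlens) and the maximum sequence length, which the variable-length kernel uses to keep sequences from attending to each other. A plain 0/1 mask makes a whole packed row look like one sequence. The sketch below is a pure-Python illustration, not dart_math's code, and it assumes a mask convention in which each token stores the 1-based index of the sample it was packed from (0 for padding).

```python
from itertools import accumulate
from typing import Dict, List, Tuple


def unpad_data_from_segment_mask(
    attention_mask: List[List[int]],
) -> Tuple[List[int], List[int], int]:
    """Return (indices, cu_seqlens, max_seqlen) for a batch of packed rows.

    Assumed (hypothetical) mask convention: value k >= 1 marks a token of the
    k-th sample packed into the row, 0 marks padding. Samples are assumed to be
    contiguous and numbered in order of appearance.
    """
    seqlens: List[int] = []
    indices: List[int] = []
    for row_idx, row in enumerate(attention_mask):
        counts: Dict[int, int] = {}
        for pos, seg in enumerate(row):
            if seg > 0:
                # Position of this non-padding token in the flattened batch.
                indices.append(row_idx * len(row) + pos)
                counts[seg] = counts.get(seg, 0) + 1
        # One length per packed sample instead of one length per row.
        seqlens.extend(counts[seg] for seg in sorted(counts))
    cu_seqlens = [0] + list(accumulate(seqlens))
    return indices, cu_seqlens, max(seqlens, default=0)
```

With the row mask [1, 1, 1, 2, 2, 0], this yields two sequences of lengths 3 and 2 (cu_seqlens [0, 3, 5]); the 0/1 mask [1, 1, 1, 1, 1, 0] would instead give a single sequence of length 5.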
make_supervised_dset will:
- load, tokenize, and cache the dataset;
- pack the data points into computation sequences.
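The packing step can be pictured with the greedy sketch below, which concatenates tokenized samples in order and starts a new sequence whenever the next sample would overflow pack_len. The in-order greedy strategy, the policy of dropping samples longer than pack_len, and the segment-index attention_mask are assumptions for illustration; make_supervised_dset may make different choices.

```python
from typing import Dict, List


def _empty_pack() -> Dict[str, List[int]]:
    return {"input_ids": [], "labels": [], "attention_mask": []}


def greedy_pack(
    samples: List[Dict[str, List[int]]], pack_len: int
) -> List[Dict[str, List[int]]]:
    """Hypothetical helper: pack tokenized samples into sequences of <= pack_len tokens."""
    packs: List[Dict[str, List[int]]] = []
    cur = _empty_pack()
    n_in_cur = 0  # number of samples already in `cur`
    for s in samples:
        n = len(s["input_ids"])
        if n > pack_len:
            continue  # assumption: over-long samples are simply dropped
        if len(cur["input_ids"]) + n > pack_len:
            packs.append(cur)  # next sample does not fit; close this pack
            cur, n_in_cur = _empty_pack(), 0
        n_in_cur += 1
        cur["input_ids"].extend(s["input_ids"])
        cur["labels"].extend(s["labels"])
        # Tag every token with the 1-based index of its sample in this pack.
        cur["attention_mask"].extend([n_in_cur] * n)
    if cur["input_ids"]:
        packs.append(cur)
    return packs
```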
For a more detailed usage example, please refer to our training script for DART-Math.
Besides, for general dataset objects of the form [{"input_ids": [...], "labels": [...], "attention_mask": [...]}, ...], you can wrap them with PackedDataset to apply sequence packing:
from dart_math.train import PackedDataset
# ...
dset = PackedDataset(dataset=dset, tokenizer=tokenizer, pack_len=4096)
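As an illustration of that data-point format, the sketch below builds one {"input_ids", "labels", "attention_mask"} entry from a query and a response. It assumes the common HuggingFace convention of setting labels to -100 on prompt tokens so that the loss only covers the response; build_sft_dp and the character-level toy_tokenize are hypothetical helpers, not part of dart_math.

```python
from typing import Callable, Dict, List

# Label value ignored by PyTorch's CrossEntropyLoss by default (ignore_index=-100).
IGNORE_INDEX = -100


def build_sft_dp(
    query: str, resp: str, tokenize: Callable[[str], List[int]]
) -> Dict[str, List[int]]:
    """Hypothetical helper: one supervised data point for a query/response pair."""
    q_ids = tokenize(query)
    r_ids = tokenize(resp)
    input_ids = q_ids + r_ids
    # Only the response tokens are supervised.
    labels = [IGNORE_INDEX] * len(q_ids) + r_ids
    attention_mask = [1] * len(input_ids)
    return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}


def toy_tokenize(text: str) -> List[int]:
    """Hypothetical stand-in tokenizer: one id (the code point) per character."""
    return [ord(c) for c in text]
```

A list of such dicts is the kind of object the paragraph above describes as input to PackedDataset.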
For more details or more general interfaces, please refer to the documentation of dart_math.train.
dart_math.gen: Efficient Generation with Flexible Stopping Criteria
Difficulty-Aware Rejection Sampling (with Code Execution) in 5 Lines of Code
from dart_math.data import RespSampleVLLM, load_query_dps
from dart_math.eval import EvaluatorMathBatch
from dart_math.exec import CodeExecCfg
from dart_math.gen import Generator, is_dp_dars_finished
# ...
generator = Generator(llm, sampling_params, resp_sample_cls=RespSampleVLLM, batch_evaluator=(EvaluatorMathBatch() if not args.gen_only else None), code_exec_cfg=CodeExecCfg.load_from_id_or_path(args.code_exec_cfg) if args.code_exec_cfg else None)
generator.gen(query_dps=query_dps, dp_stop_criteria=is_dp_dars_finished, save_path=args.gen_save_path, n_paths_per_save=args.save_gen_path_bs)
generator.gen generates with the vLLM model llm using the sampling parameters sampling_params on the query data points query_dps until every data point meets the stopping criterion dp_stop_criteria.
- Samples are generated in batches and evaluated with batch_evaluator, if specified.
- Generated samples are saved to save_path.
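The idea of stopping per data point can be sketched as a simple rejection-sampling loop: each query is sampled in batches until it has collected a target number of correct responses or has exhausted an attempt budget, so harder queries naturally receive more samples. The function rejection_sample and its parameters (sample_fn, is_correct, n_correct_target, max_attempts) are hypothetical; this is not the signature or implementation of Generator.gen or is_dp_dars_finished, and unlike this per-query loop, an efficient implementation would typically batch prompts across queries.

```python
from typing import Callable, Dict, List, Tuple


def rejection_sample(
    queries: List[str],
    sample_fn: Callable[[str, int], List[str]],  # hypothetical: (query, n) -> n responses
    is_correct: Callable[[str, str], bool],  # hypothetical: (query, response) -> verdict
    n_correct_target: int,
    max_attempts: int,
    batch_size: int = 4,
) -> Tuple[Dict[str, List[str]], Dict[str, int]]:
    """Sample each query until it has `n_correct_target` correct responses or
    `max_attempts` samples were requested, whichever comes first."""
    kept: Dict[str, List[str]] = {q: [] for q in queries}
    attempts: Dict[str, int] = {q: 0 for q in queries}

    def finished(q: str) -> bool:
        # Per-query stopping criterion.
        return len(kept[q]) >= n_correct_target or attempts[q] >= max_attempts

    active = [q for q in queries if not finished(q)]
    while active:
        for q in active:
            n = min(batch_size, max_attempts - attempts[q])
            attempts[q] += n  # count requested samples, so the loop always terminates
            for resp in sample_fn(q, n):
                if len(kept[q]) < n_correct_target and is_correct(q, resp):
                    kept[q].append(resp)
        active = [q for q in active if not finished(q)]
    return kept, attempts
```

For example, with n_correct_target=3, max_attempts=10 and batch_size=4, a query whose samples are always correct stops after 4 requested samples, whereas one whose samples are always wrong consumes all 10.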
For a more detailed usage example, please refer to our generation script for DART-Math.
dart_math.eval: Elaborate (Mathematical) Evaluation
EvaluatorMath implements an elaborate evaluation pipeline for mathematical reasoning tasks.
For more details or more general interfaces, please refer to the documentation of dart_math.eval.
Accurately Extracting Answer Strings
EvaluatorMath can:
- extract short answers from long responses fairly accurately;
- normalize them into mathematical expressions.
# MATH-style boxed answer
math_evaluator.extract_ans("Therefore, $1+1=\\boxed{2}$.")
'2'
# Answer around "answer"
math_evaluator.extract_ans("Both $1$ and $11$ divide $11,$ so $\\boxed{11}=2$, and since $1,$ $2,$ $4,$ $5,$ $10,$ and $20$ divide $20,$ then $\\boxed{20}=6$. The inner expression, $\\boxed{11}\\times\\boxed{20}=2\\times6=12$. Finally, $\\boxed{12}=6$ because $1,$ $2,$ $3,$ $4,$ $6,$ and $12$ divide $12.$\n\nTherefore, $6$ is our answer. Please note that we have not boxed the correct answer as we normally do, as that would be especially confusing for this problem."
)
'6'
# Use the last number by default
math_evaluator.extract_ans('First, we need to count the total number of letters in the word "CIRCLE". There are 6 letters.\n\nNext, we need to count the number of distinct letters. There are 6 distinct letters in the word "CIRCLE": C, I, R, L, E, and G.\n\nNow, let\'s consider the arrangements of the distinct letters. The number of ways to arrange n distinct items is n factorial (n!). So, we have 6! = 6 × 5 × 4 × 3 × 2 × 1 = 720 ways to arrange the distinct letters.\n\nHowever, the word "CIRCLE" has one letter that repeats (the letter \'C\' repeats twice). We have over-counted the number of distinct arrangements by including arrangements that are just rotations of each other (for example, "CIRCLE" and "LCIRCE" are considered different arrangements here, but they are the same word when read).\n\nTo correct for this, we divide the total number of arrangements by the number of ways to arrange the repeated letters. The number of ways to arrange 2 identical items is 2! = 2 × 1 = 2. So, we divide the total number of arrangements by 2 to get the correct number of distinct arrangements.\n\nTherefore, the number of ways to arrange the letters of the word "CIRCLE" is 720 ÷ 2 = 360.'
)
'360'
# Normalize fraction
math_evaluator.extract_ans("The answer is 1/2")
'\\frac{1}{2}'
# Normalize pmatrix
math_evaluator.extract_ans("The answer is \\begin{pmatrix} 3 \\\\ \\frac{\\pi}{2} \\end{pmatrix}")
'\\begin{array}3\\\\frac{\\pi}{2}\\end{array}'
# More cases ...
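To give a feel for the simplest of these heuristics, the sketch below takes the content of the last \boxed{...} (matching nested braces) and otherwise falls back to the last number in the text. It is a simplified, hypothetical stand-in, not EvaluatorMath.extract_ans: it has no keyword-based rules and no normalization, so on the "Answer around 'answer'" example above it would return the last boxed value '12' rather than '6'.

```python
import re
from typing import Optional

# Integers or decimals, optionally with thousands separators (e.g. 1,000).
_NUM_RE = re.compile(r"-?\d+(?:,\d{3})*(?:\.\d+)?")


def extract_last_boxed(text: str) -> Optional[str]:
    """Return the content of the last \\boxed{...}, matching nested braces."""
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    i = start + len("\\boxed{")
    depth = 1
    for j in range(i, len(text)):
        if text[j] == "{":
            depth += 1
        elif text[j] == "}":
            depth -= 1
            if depth == 0:
                return text[i:j]
    return None  # unbalanced braces


def extract_ans_sketch(text: str) -> Optional[str]:
    """Hypothetical helper: prefer the last boxed answer, else the last number."""
    boxed = extract_last_boxed(text)
    if boxed is not None:
        return boxed
    nums = _NUM_RE.findall(text)
    return nums[-1].replace(",", "") if nums else None
```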
Correctly Processing Various Mathematical Objects / Special Text
EvaluatorMath, based on regular expressions and SymPy symbolic computation, can correctly process:
- most mathematical objects, such as matrices (vectors), intervals, and symbols besides numbers;
- some special text, such as Boolean expressions, dates, and times.
math_evaluator.eq("x+y", "y+x")  # Expression
True
math_evaluator.eq("\\frac{1}{2}", "0.5")  # LaTeX
True
math_evaluator.eq("\\begin{array}1\\\\2\\end{array}", "1,2")  # Matrix (Vector)
True
math_evaluator.eq("{1,2}", "{2,1}", compare_sets=True)  # Set
True
math_evaluator.eq("no", "false")  # Bool
True
# More mathematical objects and special texts ...
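The normalize-then-compare idea can be sketched with the standard library alone: the hypothetical eq_sketch below treats yes/true and no/false as Booleans, compares {...} sets element by element regardless of order, and parses integers, decimals, a/b and simple \frac{a}{b} into exact fractions. It deliberately omits SymPy, so it cannot tell that "x+y" equals "y+x"; that needs a symbolic check, for instance testing whether sympy.simplify(sympy.sympify("x+y") - sympy.sympify("y+x")) is zero.

```python
from fractions import Fraction
from typing import List, Optional

_BOOLS = {"yes": True, "true": True, "no": False, "false": False}


def _to_number(s: str) -> Optional[Fraction]:
    """Parse integers, decimals, a/b, and simple \\frac{a}{b} into exact fractions."""
    if s.startswith("\\frac{") and s.endswith("}"):
        num, sep, den = s[len("\\frac{"):-1].partition("}{")
        if sep:
            s = f"{num}/{den}"
    try:
        return Fraction(s)
    except (ValueError, ZeroDivisionError):
        return None


def _split_set(s: str) -> Optional[List[str]]:
    if len(s) >= 2 and s[0] == "{" and s[-1] == "}":
        return [x.strip() for x in s[1:-1].split(",")]
    return None


def eq_sketch(a: str, b: str, compare_sets: bool = False) -> bool:
    """Hypothetical helper: stdlib-only equivalence check (no symbolic algebra)."""
    a, b = a.strip(), b.strip()
    if a == b:
        return True
    if a.lower() in _BOOLS and b.lower() in _BOOLS:
        return _BOOLS[a.lower()] == _BOOLS[b.lower()]
    if compare_sets:
        sa, sb = _split_set(a), _split_set(b)
        if sa is not None and sb is not None:
            # Order-insensitive: every element needs an equal partner on the other side.
            return all(any(eq_sketch(x, y) for y in sb) for x in sa) and all(
                any(eq_sketch(y, x) for x in sa) for y in sb
            )
    na, nb = _to_number(a), _to_number(b)
    return na is not None and nb is not None and na == nb
```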
Batch Evaluation with Timeout
SymPy symbolic computation can take an extremely long time on some inputs. To address this, we implement EvaluatorMathBatch, which evaluates in batch with a timeout while remaining efficient: it is based on asyncio coroutines instead of the multiprocessing used in previous implementations.
batch_evaluator = EvaluatorMathBatch()
answers, corrects = batch_evaluator.batch_eval(resp_samples)
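One way to bound per-item evaluation time with asyncio is sketched below: each blocking check runs in a worker thread, and asyncio.wait_for stops waiting for it after timeout seconds, returning a default instead. The function batch_eval_with_timeout is hypothetical and is not EvaluatorMathBatch's implementation; note that a timed-out thread cannot be killed and keeps running in the background, which is one trade-off compared with terminating worker processes.

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List, Sequence


def batch_eval_with_timeout(
    fn: Callable[[Any], Any],
    items: Sequence[Any],
    timeout: float = 1.0,
    default: Any = False,
    max_workers: int = 8,
) -> List[Any]:
    """Hypothetical helper: apply `fn` to every item concurrently; items that
    exceed `timeout` seconds get `default`. Results keep the input order."""

    async def run_all() -> List[Any]:
        loop = asyncio.get_running_loop()
        executor = ThreadPoolExecutor(max_workers=max_workers)

        async def one(item: Any) -> Any:
            try:
                # The blocking call runs in a worker thread; the coroutine only waits.
                return await asyncio.wait_for(
                    loop.run_in_executor(executor, fn, item), timeout
                )
            except asyncio.TimeoutError:
                return default

        try:
            return await asyncio.gather(*(one(x) for x in items))
        finally:
            # Do not block on timed-out threads; they finish on their own.
            executor.shutdown(wait=False)

    return asyncio.run(run_all())
```

For example, batch_eval_with_timeout(time.sleep, [0, 0.5], timeout=0.1, default="timed out") returns [None, "timed out"].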
More Details
Browse the sidebar for more details on the different modules.