API reference
This page documents the main public pipelines, training helpers, reweighter classes and plotting utilities exposed by mcreweight.
Pipelines
- mcreweight.core.run_reweighting_pipeline(args, plotdir, weightsdir)
Main function to run the reweighting pipeline.
- mcreweight.core.apply_weights_pipeline(args, plotdir, weightsdir, out_weightsdir)
Main function to apply weights to the data using the trained model.
Training helpers
- mcreweight.train.train_and_test(mc, data, mcweights, sweights, columns, test_size)
Split the data into training and testing sets.
- Args:
  - mc (pd.DataFrame): MC data.
  - data (pd.DataFrame): Data to reweight to.
  - mcweights (np.ndarray): Weights for the MC data.
  - sweights (np.ndarray): sWeights for the data.
  - columns (list): List of column names to use for training.
  - test_size (float): Proportion of the dataset to include in the test split.
- Returns:
sample (dict): Dictionary containing the training and testing splits for MC and data, along with their weights.
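The split performed by train_and_test can be sketched as follows. This is a minimal sketch, not the helper's code. It assumes NumPy arrays instead of DataFrames; in the real helper, `columns` would presumably select the features first, e.g. `mc[columns]`. The function name `split_sample`, the `seed` argument and the dictionary keys are hypothetical. Only `mc_test` echoes a name used in the training-helper descriptions.

```python
import numpy as np


def split_sample(mc, data, mcweights, sweights, test_size=0.3, seed=42):
    """Hypothetical sketch: random train/test split of MC and data with their weights."""
    rng = np.random.default_rng(seed)
    sample = {}
    for name, x, w in (("mc", mc, mcweights), ("data", data, sweights)):
        x, w = np.asarray(x), np.asarray(w)
        idx = rng.permutation(len(x))              # shuffle event indices
        n_test = int(round(test_size * len(x)))    # number of events in the test split
        test, train = idx[:n_test], idx[n_test:]
        # Illustrative key names, not the documented layout of `sample`.
        sample[f"{name}_train"], sample[f"{name}_test"] = x[train], x[test]
        sample[f"{name}_train_w"], sample[f"{name}_test_w"] = w[train], w[test]
    return sample
```

The weights are split with the same indices as the events, so each event keeps its own weight in both the training and the test split.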
- mcreweight.train.gbreweight(args, sample, columns, study, weightsdir)
Train a gradient-boosting reweighter (GBReweighter) and predict weights for the test MC data.
- Args:
  - args: Command-line arguments containing configuration options.
  - sample (dict): Dictionary containing training and test data along with their weights.
  - columns (list): List of column names to use for training.
  - study (optuna.study.Study): Optuna study object containing the best hyperparameters.
  - weightsdir (str): Directory to save the model and weights.
- mcreweight.train.onnxgbreweight(args, sample, columns, weightsdir, study=None)
Train an ONNX-exportable GB reweighter that mirrors hep_ml’s signed-weight loss.
- mcreweight.train.xgbreweight(args, sample, columns, weightsdir, study=None)
Train an iterative XGBoost reweighter using best hyperparameters from Optuna, then predict weights on mc_test.
- Args:
  - args: Command-line arguments containing configuration options.
  - sample (dict): Dictionary containing training and test data along with their weights.
  - columns (list): List of column names to use for training.
  - study (optuna.study.Study, optional): Optuna study object containing the best hyperparameters.
  - weightsdir (str): Directory to save the model and weights.
- mcreweight.train.nnreweight(args, sample, columns, weightsdir, study=None)
Train an iterative NN reweighter using best hyperparameters from Optuna, then predict weights on mc_test.
- Args:
  - args: Command-line arguments containing configuration options.
  - sample (dict): Dictionary containing training and test data along with their weights.
  - columns (list): List of column names to use for training.
  - study (optuna.study.Study, optional): Optuna study object containing the best hyperparameters.
  - weightsdir (str): Directory to save the model and weights.
- Returns:
  - nn (ONNXINNReweighter): Trained NN reweighter model.
  - new_mc_weights (np.ndarray): Predicted weights for the MC data using the trained NN reweighter.
- mcreweight.train.binning_reweight(args, sample, columns, n_bins, n_neighs, weightsdir)
Train an ONNXBinsReweighter and predict weights for the test MC data. Mirrors the behavior of the original BinsReweighter version, with ONNX support.
- Args:
  - args: Command-line arguments containing configuration options.
  - sample (dict): Dictionary containing training and test data along with their weights.
  - columns (list): List of column names to use for training.
  - n_bins (int): Number of bins to use for the reweighter.
  - n_neighs (float): Number of neighbors for smoothing (ignored in this function).
  - weightsdir (str): Directory to save the model and weights.
- mcreweight.train.gbfolding(args, gb, sample, columns, n_folds, weightsdir)
Train a folding reweighter using the base GB model, then predict weights on mc_test.
- Args:
  - args: Command-line arguments containing configuration options.
  - gb (GBReweighter): Base GB reweighter model to use for folding.
  - sample (dict): Dictionary containing training and test data along with their weights.
  - columns (list): List of column names to use for training.
  - n_folds (int): Number of folds for k-folding.
  - weightsdir (str): Directory to save the model and weights.
- mcreweight.train.onnxfolding(args, onnxgb, sample, columns, n_folds, weightsdir)
Train an ONNXFoldingReweighter using the ONNXGB base model.
- mcreweight.train.xgbfolding(args, xgb, sample, columns, n_folds, weightsdir, n_iterations=15)
Train an ONNXIXGBFoldingReweighter using the base XGB model, then predict weights on mc_test.
- Args:
  - args: Command-line arguments containing configuration options.
  - xgb (ONNXIXGBReweighter): Base XGB reweighter model to use for folding.
  - sample (dict): Dictionary containing training and test data along with their weights.
  - columns (list): List of column names to use for training.
  - n_folds (int): Number of folds for k-folding.
  - weightsdir (str): Directory to save the model and weights.
  - n_iterations (int): Number of iterations for the base XGB model.
- mcreweight.train.nnfolding(args, nn, sample, columns, n_folds, weightsdir, n_iterations=5)
Train an ONNXINNFoldingReweighter using the base NN model, then predict weights on mc_test.
- Args:
  - args: Command-line arguments containing configuration options.
  - nn (ONNXINNReweighter): Base NN reweighter model to use for folding.
  - sample (dict): Dictionary containing training and test data along with their weights.
  - columns (list): List of column names to use for training.
  - n_folds (int): Number of folds for k-folding.
  - weightsdir (str): Directory to save the model and weights.
  - n_iterations (int): Number of iterations for the base NN model.
Main reweighter classes
- class mcreweight.models.onnxreweighter.ONNXGBReweighter(transform=None, verbosity=1, n_estimators=40,
      learning_rate=0.2, max_depth=3, min_samples_leaf=200, loss_regularization=5.0, subsample=1.0,
      min_samples_split=2, max_features=None, max_leaf_nodes=None, splitter='best', update_tree=True,
      random_state=42, store_dir=None, eps=1e-06)
ONNX-exportable implementation of hep_ml’s GBReweighter logic.
- This mirrors the signed-weight handling of hep_ml’s ReweightLossFunction:
  - tree targets are class/sign-based residuals;
  - tree fit weights use abs(normalized signed weights);
  - leaf updates are log(target_weight + reg) - log(original_weight + reg).
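The fit-weight and leaf-update rules above can be illustrated for a single boosting step with a short NumPy sketch. It is not the class's implementation and rests on several assumptions:
- the tree has already assigned each event to a leaf (`leaf_ids`);
- each class is normalized to unit sum, although the exact normalization used by hep_ml and by ONNXGBReweighter may differ;
- `reg` plays the role of loss_regularization and `eps` is used as a floor inside the logarithm.

The sketch does not reproduce the tree targets or the accumulation of predictions across boosting steps. `leaf_step` is a hypothetical name.

```python
import numpy as np


def leaf_step(leaf_ids, is_target, signed_weights, n_leaves, reg=5.0, eps=1e-6):
    """Sketch of one boosting step's fit weights and leaf values for signed weights."""
    leaf_ids = np.asarray(leaf_ids)
    is_target = np.asarray(is_target, dtype=bool)
    w = np.asarray(signed_weights, dtype=float)
    # Normalize the signed weights separately for each class (original / target).
    w_norm = np.empty_like(w)
    for cls in (False, True):
        mask = is_target == cls
        w_norm[mask] = w[mask] / w[mask].sum()
    fit_weights = np.abs(w_norm)  # weights passed to the tree fit
    # Per-leaf sums of normalized signed weights for each class.
    w_orig = np.bincount(leaf_ids[~is_target], weights=w_norm[~is_target], minlength=n_leaves)
    w_targ = np.bincount(leaf_ids[is_target], weights=w_norm[is_target], minlength=n_leaves)
    # Leaf update: log(target_weight + reg) - log(original_weight + reg), floored at eps.
    leaf_values = np.log(np.maximum(w_targ + reg, eps)) - np.log(np.maximum(w_orig + reg, eps))
    return fit_weights, leaf_values
```

A larger `reg` pulls every leaf value toward zero, which damps updates driven by leaves that contain little weight.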
- class mcreweight.models.onnxreweighter.ONNXIXGBReweighter(transform=None, verbosity=1, n_iterations=30,
      mixing_learning_rate=0.05, clip_delta=2.0, max_log_weight=3.0, mixing_subsample=1.0, random_state=42,
      store_dir=None, reweight_validation_fraction=0.2, reweight_early_stopping_rounds=5,
      reweight_metric_every=1, **xgb_params)
- get_params()
Return the parameters of the reweighter, including both the ONNXReweighterMixin parameters and the XGBoost parameters. This can be useful for logging, debugging, or reproducing the model configuration.
- set_params(**params)
Set the parameters of the reweighter, allowing updates to both the ONNXReweighterMixin parameters and the XGBoost parameters. This can be useful for tuning the model configuration after initialization.
- Args:
params: Arbitrary keyword arguments corresponding to the parameters to be updated.
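The parameter names n_iterations, mixing_learning_rate, clip_delta and max_log_weight suggest an iterative, classifier-based update of per-event log-weights. The same parameters also appear on ONNXINNReweighter. The sketch below shows one plausible reading of a single iteration. It is an assumption, not the internals of either class:
- a classifier's probability `p_target` that an MC event looks like target is turned into a log-odds correction;
- the correction is clipped to ±clip_delta and scaled by mixing_learning_rate;
- the accumulated log-weight is clipped to ±max_log_weight.

Subsampling and early stopping are left out. `update_log_weights` is a hypothetical name.

```python
import numpy as np


def update_log_weights(log_w, p_target, mixing_learning_rate=0.05,
                       clip_delta=2.0, max_log_weight=3.0, eps=1e-6):
    """Hypothetical single iteration of a classifier-based log-weight update."""
    p = np.clip(np.asarray(p_target, dtype=float), eps, 1.0 - eps)
    # Log-odds of a balanced classifier estimate log(target density / MC density).
    delta = np.clip(np.log(p / (1.0 - p)), -clip_delta, clip_delta)
    new_log_w = np.asarray(log_w, dtype=float) + mixing_learning_rate * delta
    # Bound the accumulated log-weight to avoid extreme event weights.
    return np.clip(new_log_w, -max_log_weight, max_log_weight)
```

Clipping at both levels limits how much a single over-confident classifier output can change one event's weight.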
- class mcreweight.models.onnxreweighter.ONNXINNReweighter(transform=None, verbosity=1, n_iterations=30,
      mixing_learning_rate=0.1, clip_delta=3.0, max_log_weight=3.0, mixing_subsample=1.0, random_state=42,
      store_dir=None, reweight_validation_fraction=0.2, reweight_early_stopping_rounds=5,
      reweight_metric_every=1, **nn_params)
- get_params()
Return the parameters of the reweighter, including both the ONNXReweighterMixin parameters and the Neural Network parameters. This can be useful for logging, debugging, or reproducing the model configuration.
- set_params(**params)
Set the parameters of the reweighter, allowing updates to both the ONNXReweighterMixin parameters and the Neural Network parameters. This can be useful for tuning the model configuration after initialization.
- Args:
params: Arbitrary keyword arguments corresponding to the parameters to be updated.
- class mcreweight.models.onnxreweighter.ONNXBinsReweighter(transform=None, verbosity=1, n_bins=50, n_neighs=2, min_in_bin=1.0, eps=1e-06)
N-dimensional histogram reweighter with simple neighbor smoothing.
It relies on the common helpers provided by BaseONNXReweighter for feature transformations, input validation, and per-class weight normalization.
- fit(original, target, ow=None, tw=None)
Fit the reweighter by computing the ratio of N-dimensional histograms of original and target datasets.
- Args:
  - original (array-like): Original dataset (e.g., MC samples).
  - target (array-like): Target dataset (e.g., real data samples).
  - ow (array-like, optional): Sample weights for the original dataset. If None, uniform weights are used.
  - tw (array-like, optional): Sample weights for the target dataset. If None, uniform weights are used.
- predict_weights(X, ow=None)
Predict reweighting factors for the input samples by looking up the ratio in the corresponding histogram bin.
- Args:
  - X (array-like): Input samples for which to predict weights.
  - ow (array-like, optional): Original weights for the input samples. If None, uniform weights are assumed.
- Returns:
np.ndarray: Predicted reweighting factors for the input samples.
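A minimal NumPy sketch of the fit/predict cycle described above, for illustration only. It assumes equal-width bins over the combined range of both samples, and it omits the neighbor smoothing and the min_in_bin handling. `fit_bins` and `predict_bins` are hypothetical names. Multiplying the looked-up ratio by `ow` follows the convention of hep_ml's BinsReweighter.predict_weights; that ONNXBinsReweighter does the same is an assumption.

```python
import numpy as np


def fit_bins(original, target, ow=None, tw=None, n_bins=10, eps=1e-6):
    """Return per-axis bin edges and the target/original histogram ratio."""
    original = np.asarray(original, dtype=float)
    target = np.asarray(target, dtype=float)
    ow = np.ones(len(original)) if ow is None else np.asarray(ow, dtype=float)
    tw = np.ones(len(target)) if tw is None else np.asarray(tw, dtype=float)
    # Common equal-width binning over the combined range of both samples.
    both = np.vstack([original, target])
    edges = [np.linspace(both[:, d].min(), both[:, d].max(), n_bins + 1)
             for d in range(both.shape[1])]
    # Per-class normalization: each weighted histogram sums to one.
    h_o, _ = np.histogramdd(original, bins=edges, weights=ow / ow.sum())
    h_t, _ = np.histogramdd(target, bins=edges, weights=tw / tw.sum())
    ratio = np.maximum(h_t, eps) / np.maximum(h_o, eps)
    return edges, ratio


def predict_bins(X, edges, ratio, ow=None):
    """Look up the ratio in each event's bin and multiply by the original weight."""
    X = np.asarray(X, dtype=float)
    ow = np.ones(len(X)) if ow is None else np.asarray(ow, dtype=float)
    # Bin index per axis; events outside the fitted range go to the edge bins.
    idx = tuple(
        np.clip(np.searchsorted(e, X[:, d], side="right") - 1, 0, len(e) - 2)
        for d, e in enumerate(edges)
    )
    return ow * ratio[idx]
```

The `eps` floor keeps the ratio finite in bins with no original events. The flip side is that bins with no target events give near-zero weights.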
- save(prefix)
Save the histogram edges and ratio to disk, along with meta information.
- Args:
  - prefix (str): Prefix for the saved files. The edges and ratio are saved to <prefix>_edges.npy and <prefix>_ratio.npy.
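The two documented file names can be produced with np.save. The sketch below writes and reads them back. It assumes the same number of bins on every axis, so that the edges stack into one regular array. The `<prefix>_meta.json` file is a hypothetical stand-in, because the format of the meta information is not documented here. `save_bins` and `load_bins` are hypothetical names.

```python
import json
import os

import numpy as np


def save_bins(prefix, edges, ratio, meta=None):
    """Sketch: write edges and ratio as <prefix>_edges.npy and <prefix>_ratio.npy."""
    os.makedirs(os.path.dirname(prefix) or ".", exist_ok=True)
    # np.stack needs the same number of edges on every axis (equal n_bins).
    np.save(f"{prefix}_edges.npy", np.stack(edges))
    np.save(f"{prefix}_ratio.npy", ratio)
    # Hypothetical meta file; the real meta format is not specified in this reference.
    with open(f"{prefix}_meta.json", "w") as fh:
        json.dump(meta or {}, fh)


def load_bins(prefix):
    """Read back the per-axis edges (as a list) and the ratio array."""
    edges = list(np.load(f"{prefix}_edges.npy"))
    ratio = np.load(f"{prefix}_ratio.npy")
    return edges, ratio
```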
Folding classes
- class mcreweight.models.onnxfolding.ONNXFoldingReweighter(n_folds=5, shuffle=True, random_state=42, transform=None, verbosity=1, **gb_params)
k-folding ensemble of ONNXGBReweighter models.
- class mcreweight.models.onnxfolding.ONNXIXGBFoldingReweighter(n_folds=5, shuffle=True, random_state=42, transform=None, verbosity=1, **xgb_params)
k-folding ensemble of ONNXIXGBReweighter models.
- class mcreweight.models.onnxfolding.ONNXINNFoldingReweighter(n_folds=5, shuffle=True, random_state=42, transform=None, verbosity=1, **nn_params)
k-folding ensemble of ONNXINNReweighter models.
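These ensembles presumably follow the usual k-folding scheme, as in hep_ml's FoldingReweighter:
- each of the n_folds models is trained without one fold;
- events used in training are weighted by the model that did not see them;
- new events receive an aggregate over all fold models.

The sketch below illustrates that scheme on 1D NumPy arrays. A toy histogram-ratio model stands in for the GB, XGB or NN base reweighters. The mean used for new events and all function names are assumptions, not the classes' actual behavior.

```python
import numpy as np


def hist_ratio_model(original, target, edges, eps=1e-6):
    """Toy 1D base reweighter (histogram ratio); stands in for a GB/XGB/NN model."""
    h_o, _ = np.histogram(original, bins=edges)
    h_t, _ = np.histogram(target, bins=edges)
    ratio = (h_t / h_t.sum()) / np.maximum(h_o / h_o.sum(), eps)

    def predict(x):
        idx = np.clip(np.searchsorted(edges, x, side="right") - 1, 0, len(edges) - 2)
        return ratio[idx]

    return predict


def folding_weights(original, target, edges, n_folds=5, random_state=42):
    """Out-of-fold weights for `original`, plus an averaged predictor for new events."""
    rng = np.random.default_rng(random_state)
    fo = rng.permutation(len(original)) % n_folds  # balanced random fold labels
    ft = rng.permutation(len(target)) % n_folds
    oof = np.empty(len(original))
    models = []
    for k in range(n_folds):
        # Train on every fold except k, then predict only the held-out fold k.
        model = hist_ratio_model(original[fo != k], target[ft != k], edges)
        oof[fo == k] = model(original[fo == k])
        models.append(model)

    def predict_new(x):
        return np.mean([m(x) for m in models], axis=0)  # average over fold models

    return oof, predict_new
```

Weighting training events only with models that never saw them keeps the weights from being biased by overfitting to those same events.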
Plotting utilities
- mcreweight.utils.plotting_utils.set_lhcb_style(grid=True, size=12, usetex=False)
Set matplotlib plotting style close to “official” LHCb style (TeX Gyre Termes serif font, inward ticks on all sides, minor ticks, light grid).
- mcreweight.utils.plotting_utils.plot_correlation_matrix(args, df, columns, weights, x_labels, title, output_file)
Plot a correlation matrix for the given DataFrame columns.
- Args:
  - args (argparse.Namespace): Command-line arguments containing the verbosity flag.
  - df (pd.DataFrame): DataFrame containing the data.
  - columns (list): List of column names to include in the correlation matrix.
  - weights (np.ndarray, optional): Weights for the correlation calculation. If None, unweighted correlation is used.
  - x_labels (dict): Mapping of column names to x-axis labels for the plot.
  - title (str): Title of the plot.
  - output_file (str): Path to save the output plot.
- mcreweight.utils.plotting_utils.plot_distributions(args, mc, data, mc_weights, data_weights, columns,
      x_labels, output_file, transform=None, x_edges=None, pull_clip=5)
Plot MC and Data distributions with pull plots. Both histograms are normalized as densities, so the pulls compare shapes even when MC and Data have different statistics.
- Args:
  - args (argparse.Namespace): Command-line arguments containing the verbosity flag.
  - mc, data (pd.DataFrame): MC and Data samples.
  - mc_weights, data_weights (np.ndarray): Weights for MC and Data.
  - columns (list): Columns to plot.
  - x_labels (dict): Mapping of column names to x-axis labels.
  - output_file (str): Path to save the figure.
  - transform (callable, optional): Transformation function to apply to the data.
  - x_edges (dict, optional): Mapping of column names to bin edges.
  - pull_clip (float): Maximum absolute value for pull display.
- mcreweight.utils.plotting_utils.plot_mc_distributions(mc, original_mc_weights, new_mc_weights, columns, x_labels, output_file, x_edges=None)
Plot MC distributions with the original and the new weights.
- Args:
  - mc (pd.DataFrame): MC data.
  - original_mc_weights (np.ndarray): Original weights for the MC data.
  - new_mc_weights (np.ndarray): New weights for the MC data (e.g. from a reweighter).
  - columns (list): List of column names to plot.
  - x_labels (dict): Dictionary mapping column names to x-axis labels.
  - output_file (str): Path to save the output plot.
  - x_edges (dict, optional): Dictionary mapping column names to bin edges for histogramming.
- mcreweight.utils.plotting_utils.plot_training_throughput(throughput, output_file)
Plot training throughput metrics for each method.
- Args:
  - throughput (dict): Mapping of method to throughput metric dictionary.
  - output_file (str): Output file path.
- mcreweight.utils.plotting_utils.plot_training_memory(memory_profile, output_file)
Plot training memory metrics for each method.
- Args:
  - memory_profile (dict): Mapping of method to memory metric dictionary.
  - output_file (str): Output file path.
- mcreweight.utils.plotting_utils.plot_roc_curve(sample, weights, methods, columns, output_file)
Plot ROC curves for the different reweighting methods.
- Args:
  - sample (dict): Dictionary containing MC and Data samples and their weights.
  - weights (dict): Dictionary containing weights for each method (GB, Folding, XGB, k-Folding, Bins, NN).
  - methods (list): List of methods to include in the plot.
  - columns (list): List of column names to use for plotting.
  - output_file (str): Path to save the output plot.
- Returns:
scores: Dictionaries containing classifier scores for each method.
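Assume each method's scores come from a classifier trained to separate reweighted MC from data. A weighted ROC curve can then be built directly from the scores and the event weights, and an AUC close to 0.5 means the classifier can no longer tell the two samples apart. The NumPy sketch below treats data as the positive class, merges tied scores into a single threshold, and integrates with the trapezoidal rule. The function name is hypothetical, and this is not necessarily how plot_roc_curve builds its curves. sklearn.metrics.roc_curve with sample_weight is a library alternative.

```python
import numpy as np


def weighted_roc_auc(scores_mc, w_mc, scores_data, w_data):
    """Weighted ROC curve and AUC for separating data (label 1) from MC (label 0)."""
    s = np.concatenate([scores_data, scores_mc]).astype(float)
    y = np.concatenate([np.ones(len(scores_data)), np.zeros(len(scores_mc))])
    w = np.concatenate([w_data, w_mc]).astype(float)
    order = np.argsort(-s, kind="mergesort")  # descending score
    s, y, w = s[order], y[order], w[order]
    tp = np.cumsum(w * y)          # weighted true positives
    fp = np.cumsum(w * (1.0 - y))  # weighted false positives
    last = np.r_[np.nonzero(np.diff(s))[0], len(s) - 1]  # one threshold per distinct score
    tpr = np.r_[0.0, tp[last] / tp[-1]]
    fpr = np.r_[0.0, fp[last] / fp[-1]]
    auc = np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2.0)  # trapezoidal rule
    return fpr, tpr, auc
```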
- mcreweight.utils.plotting_utils.plot_classifier_output(scores, weights, methods, output_file, min_score=0.0, max_score=1.0)
Produce two sets of classifier output plots:
  - output_file: all methods’ MC score distributions overlaid on a single set of axes, without a Target line, for a quick visual comparison.
  - output_file with _{method} inserted before the extension: one file per method showing MC (solid) and Target (dashed) from the same per-method classifier, so that the KS comparison is self-consistent.
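Two pieces of this description can be made concrete: the naming rule for the per-method files, and a KS-style comparison between the MC and Target score distributions. The sketch assumes a weighted two-sample KS distance, i.e. the maximum difference between weighted empirical CDFs. The exact statistic computed by plot_classifier_output is not specified here, and both function names are hypothetical.

```python
import os

import numpy as np


def per_method_path(output_file, method):
    """Insert _{method} before the file extension, e.g. scores.pdf -> scores_GB.pdf."""
    root, ext = os.path.splitext(output_file)
    return f"{root}_{method}{ext}"


def weighted_ks(a, wa, b, wb):
    """Maximum distance between the weighted empirical CDFs of samples a and b."""
    a, wa, b, wb = (np.asarray(v, dtype=float) for v in (a, wa, b, wb))
    ia, ib = np.argsort(a), np.argsort(b)
    a, wa, b, wb = a[ia], wa[ia], b[ib], wb[ib]
    cdf_a = np.r_[0.0, np.cumsum(wa) / wa.sum()]
    cdf_b = np.r_[0.0, np.cumsum(wb) / wb.sum()]
    grid = np.union1d(a, b)
    # searchsorted(..., side="right") counts the points <= each grid value.
    fa = cdf_a[np.searchsorted(a, grid, side="right")]
    fb = cdf_b[np.searchsorted(b, grid, side="right")]
    return float(np.max(np.abs(fa - fb)))
```

Evaluating both CDFs on the union of the sample values is enough, because the weighted empirical CDFs only change at those points.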
- mcreweight.utils.plotting_utils.plot_weight_distributions(weights, output_file, bins=50, xlim=(0, 10))
Plot histograms of weight distributions.
- Args:
  - weights (dict): Dictionary where keys are labels and values are arrays of weights.
  - output_file (str): Output file path for the plot.
  - bins (int): Number of histogram bins.
  - xlim (tuple or None): Limits for the x-axis, e.g. (0, 5). Default: (0, 10).
- mcreweight.utils.plotting_utils.plot_2d_score_maps(sample, weights, classifier_scores, method, vars, output_file, x_labels, n_bins=40)
Plot 2D heatmaps of mean classifier score vs all possible pairs of variables.
- Args:
  - sample (dict): Dictionary containing MC and Data samples.
  - weights (dict): Dictionary of weights for each sample.
  - classifier_scores (dict): Dictionary of classifier scores for each sample.
  - method (str): Reweighter method name.
  - vars (list): List of variables to consider for 2D plots.
  - output_file (str): Path to save the figure.
  - x_labels (dict): Dictionary mapping column names to x-axis labels.
  - n_bins (int): Number of bins for the 2D histogram.
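One way to build such a map is as the weighted average of the per-event classifier score in each bin. That is the sum of weight times score divided by the sum of weight, computed from two weighted 2D histograms. The maps can then be looped over all unordered variable pairs with itertools.combinations. Whether plot_2d_score_maps weights the mean, and the function names, are assumptions of this sketch.

```python
from itertools import combinations

import numpy as np


def mean_score_map(x, y, scores, weights, n_bins=40):
    """Weighted mean classifier score in each (x, y) bin; NaN for empty bins."""
    x, y, s, w = (np.asarray(v, dtype=float) for v in (x, y, scores, weights))
    sum_ws, xe, ye = np.histogram2d(x, y, bins=n_bins, weights=w * s)
    sum_w, _, _ = np.histogram2d(x, y, bins=[xe, ye], weights=w)
    with np.errstate(invalid="ignore", divide="ignore"):
        return sum_ws / sum_w, xe, ye


def all_pair_maps(features, scores, weights, n_bins=40):
    """One map per unordered pair of variables; `features` maps name -> values."""
    return {
        (a, b): mean_score_map(features[a], features[b], scores, weights, n_bins)
        for a, b in combinations(features, 2)
    }
```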
- mcreweight.utils.plotting_utils.plot_feature_importance(shap_values, feature_names, mc, x_labels, method, output_file, max_display=None)
Draw a SHAP beeswarm (summary) plot for a reweighter.
- Args:
  - shap_values (dict): Dictionary of SHAP values for each method. Keys should match method names.
  - feature_names (list): Feature names.
  - mc (pd.DataFrame): MC sample used for SHAP value computation (for feature values).
  - x_labels (dict): Dictionary mapping column names to x-axis labels.
  - method (str): Reweighter method name.
  - output_file (str): Path to save the figure.
  - max_display (int, optional): Maximum number of features to show.
- mcreweight.utils.plotting_utils.plot_2d_pull_maps(mc, data, mc_weights, data_weights, columns, x_labels,
      method, output_file, n_bins=40, pull_clip=5)
Plot 2D pull maps for all variable pairs.
The pull in each bin is computed as (data_density - mc_density) / sqrt(var_data + var_mc). Both MC and Data are normalized to densities so that samples with different statistics are compared correctly.
- Args:
  - mc, data (pd.DataFrame): MC and Data samples.
  - mc_weights, data_weights (np.ndarray): Weights for MC and Data.
  - columns (list): List of column names to consider for the pull maps.
  - x_labels (dict): Dictionary mapping column names to x-axis labels.
  - method (str): Reweighting method name (for the plot title).
  - output_file (str): Path to save the figure.
  - n_bins (int): Number of bins for the 2D histograms.
  - pull_clip (float): Maximum absolute value for pull-map clipping.
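The pull definition for plot_2d_pull_maps can be written out for a single variable pair as follows. The sketch makes several assumptions:
- equal-width bins over the combined range of both samples;
- per-bin variances estimated from the sum of squared weights, scaled by the same density normalization as the counts;
- a pull of zero in bins that are empty in both samples.

These choices, and the function name, are not necessarily the exact behavior of plot_2d_pull_maps.

```python
import numpy as np


def pull_map_2d(mc_xy, data_xy, mc_weights, data_weights, n_bins=40, pull_clip=5):
    """(data_density - mc_density) / sqrt(var_data + var_mc) on a common 2D grid."""
    mc_xy = np.asarray(mc_xy, dtype=float)
    data_xy = np.asarray(data_xy, dtype=float)
    mc_w = np.asarray(mc_weights, dtype=float)
    data_w = np.asarray(data_weights, dtype=float)
    both = np.vstack([mc_xy, data_xy])
    xe = np.linspace(both[:, 0].min(), both[:, 0].max(), n_bins + 1)
    ye = np.linspace(both[:, 1].min(), both[:, 1].max(), n_bins + 1)
    area = np.outer(np.diff(xe), np.diff(ye))  # bin areas

    def density_and_var(xy, w):
        h, _, _ = np.histogram2d(xy[:, 0], xy[:, 1], bins=[xe, ye], weights=w)
        h2, _, _ = np.histogram2d(xy[:, 0], xy[:, 1], bins=[xe, ye], weights=w**2)
        norm = w.sum() * area            # density normalization per bin
        return h / norm, h2 / norm**2    # density and its variance

    d_mc, v_mc = density_and_var(mc_xy, mc_w)
    d_data, v_data = density_and_var(data_xy, data_w)
    var = v_data + v_mc
    with np.errstate(invalid="ignore", divide="ignore"):
        pull = np.where(var > 0, (d_data - d_mc) / np.sqrt(var), 0.0)
    return np.clip(pull, -pull_clip, pull_clip), xe, ye
```

Scaling the variance by the square of the same normalization keeps the pull dimensionless, whatever the sample sizes.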