
radiradev/flowmatching-bdt


Flow-Matching BDT

A small library for training flow-matching models. It focuses on efficient algorithms for tabular learning, such as histogram-based gradient-boosted decision trees, but works with any scikit-learn-compatible regressor.

Installation

pip install flowmatching-bdt

Quick Start

from sklearn.datasets import make_moons
from flowmatching_bdt import FlowMatchingBDT

data, _ = make_moons(n_samples=500, noise=0.1, random_state=0)
model = FlowMatchingBDT()

# train the model
model.fit(data)

# generate new samples
samples = model.predict(num_samples=500)

Conditional Generation

import numpy as np
from sklearn.datasets import make_moons
from flowmatching_bdt import FlowMatchingBDT

data, labels = make_moons(n_samples=500, noise=0.1, random_state=42)
model = FlowMatchingBDT()

model.fit(data, conditions=labels)

# generate 500 samples conditioned on label 1 (make_moons labels are 0 or 1)
conditions = np.ones(500)
samples = model.predict(num_samples=500, conditions=conditions)

How It Works

Flow matching trains a model to predict a velocity field that transports samples from a simple source distribution (e.g. Gaussian noise) to the data distribution. This implementation:

  1. Discretises the flow into n_flow_steps time steps
  2. Trains one regressor per step to predict the velocity field
  3. At inference, integrates the learned field using Euler steps to generate new samples
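Written out, and assuming the common straight-line interpolation path between noise and data (the path the library uses is not stated here), the per-step training targets and the Euler update look like this. $K$ stands for n_flow_steps, $x_0$ for a noise sample, $x_1$ for a data sample, and $v_k$ for the regressor trained at step $k$:

```latex
\begin{aligned}
t_k &= \frac{k}{K}, \qquad k = 0, \dots, K-1, \qquad \Delta t = \frac{1}{K} \\
x_{t_k} &= (1 - t_k)\, x_0 + t_k\, x_1, \qquad x_0 \sim \mathcal{N}(0, I), \; x_1 \sim p_{\text{data}} \\
v_k &= \arg\min_{v} \; \mathbb{E}\, \big\| v(x_{t_k}) - (x_1 - x_0) \big\|^2 \\
\hat{x}_{t_{k+1}} &= \hat{x}_{t_k} + \Delta t \, v_k(\hat{x}_{t_k}), \qquad \hat{x}_{t_0} \sim \mathcal{N}(0, I)
\end{aligned}
```

The regression target $x_1 - x_0$ is the time derivative of the straight path, so each $v_k$ learns the average direction in which points at time $t_k$ should move.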

On tabular data, gradient-boosted trees can learn this velocity field competitively with neural networks while typically training faster.

Useful Resources

Citation

This repository started as a reproduction of the following paper:

@inproceedings{jolicoeur2024generating,
  title={Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees},
  author={Jolicoeur-Martineau, Alexia and Fatras, Kilian and Kachman, Tal},
  booktitle={International Conference on Artificial Intelligence and Statistics},
  pages={1288--1296},
  year={2024},
  organization={PMLR}
}

Acknowledgements

This repository is heavily inspired by, and borrows parts from, lucidrains (project structure) and torch-cfm.

About

Implementation of flow matching on tabular data using XGBoost