Add multithreading to table lookup (#5849) #5849

Open
ShuyangLiu wants to merge 1 commit into pytorch:main from ShuyangLiu:export-D102867249

ShuyangLiu commented Jun 8, 2026


Summary:

X-link: https://github.com/facebookresearch/FBGEMM/pull/2767

What

Parallelizes the per-table loop in the CPU TBE forward kernel
(IntNBitTableBatchedEmbeddingBagsCodegen::forward) across tables. The
per-table loop is embarrassingly parallel (each table reads its own weight
slice and writes a disjoint slice of the output), so fanning tables out across
threads speeds up table-heavy inference models: 1.4-1.7x at 2 threads and up to
2.9x at 8 threads in the microbenchmark below.

Gated by the TBE_TABLE_THREADS env var:

  • TBE_TABLE_THREADS=1 (default): unchanged sequential behavior.
  • TBE_TABLE_THREADS=N>1: tables are distributed over N OpenMP threads with
    dynamic scheduling (good load balancing when table sizes are skewed).

Default behavior is unchanged

When TBE_TABLE_THREADS<=1 the helper takes an early return and runs the loop
body sequentially with no try/catch wrapper and no thread-local-state guard, so
the default path is functionally identical to the pre-change code: same
iteration order, the DEVICE-placement TORCH_CHECK in its original per-table
position, same error semantics, and the same generated machine code for the
body. The only always-on changes are mechanical and behavior-preserving
(loop-local weights pointer, int64 loop index).

Design

  • Raw OpenMP (#pragma omp parallel + omp for, schedule(dynamic)) rather than
    at::parallel_for, so TBE gets its own thread count independent of the global
    intra-op pool / OMP_NUM_THREADS (predictors run with OMP_NUM_THREADS=1).
  • Thread count is read once from the env var and cached (thread-safe static
    init), and clamped to the number of tables.
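
The thread-count handling described above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: the helper names (parse_table_threads, configured_table_threads, effective_table_threads), the fallback to 1 on malformed input, and the upper sanity cap of 1024 are all assumptions made for the example. Only the TBE_TABLE_THREADS variable, the read-once caching, and the clamp to the table count come from the description.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>

// Hypothetical helper: parse a thread-count string. Anything missing,
// malformed, or below 1 falls back to 1 (the sequential default).
inline int parse_table_threads(const char* value) {
  if (value == nullptr || *value == '\0') {
    return 1;
  }
  char* end = nullptr;
  const long parsed = std::strtol(value, &end, 10);
  if (end == value || *end != '\0' || parsed < 1) {
    return 1;
  }
  // Assumed sanity cap so the value always fits in an int.
  return static_cast<int>(std::min<long>(parsed, 1024));
}

// Read TBE_TABLE_THREADS once. A function-local static is initialized
// exactly once, and that initialization is thread-safe since C++11.
inline int configured_table_threads() {
  static const int cached =
      parse_table_threads(std::getenv("TBE_TABLE_THREADS"));
  return cached;
}

// Clamp the configured count to the number of tables, so no thread is
// started without at least one table to process.
inline int effective_table_threads(int configured, int64_t num_tables) {
  assert(configured >= 1);
  if (num_tables <= 1) {
    return 1;
  }
  return static_cast<int>(std::min<int64_t>(configured, num_tables));
}
```

A caller would combine the two as effective_table_threads(configured_table_threads(), T) and take the sequential path whenever the result is 1.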

Correctness

  • Removed the function-scoped weights_acc pointer, which every iteration
    overwrote — a data race once the loop is parallel. Replaced with a loop-local
    pointer (identical pointer value). Every other variable in the loop body is
    already loop-local, and each table writes a disjoint output slice
    (output_acc + D_start), so results are bitwise-identical to the sequential
    path.
  • The per-table DEVICE-placement TORCH_CHECK stays in its original position. In
    the threaded path it — like any other throw from the loop body (kernel
    errors, at::arange checks) — is captured (first one wins) and rethrown after
    the join, so no exception escapes the OpenMP region.
  • Worker threads restore the caller's at::ThreadLocalState (dispatch keys,
    grad/inference mode, autocast, ...), so ATen calls inside the loop (e.g.
    at::arange in the nobag path) run with the correct thread-local context.
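
The capture-and-rethrow pattern from the second bullet can be sketched as below. This is an illustrative stand-in under stated assumptions, not the helper from this PR: for_each_table, its std::function parameter, and the critical-section name are hypothetical, and the at::ThreadLocalState restore is left out so the example needs no libtorch. If the file is compiled without OpenMP support the pragmas are ignored and the loop simply runs sequentially.

```cpp
#include <cassert>
#include <cstdint>
#include <exception>
#include <functional>
#include <stdexcept>
#include <vector>

// Hypothetical fan-out helper: body(t) processes table t and must only
// touch state that belongs to table t (for example its own output slice).
inline void for_each_table(
    int64_t num_tables,
    int threads,
    const std::function<void(int64_t)>& body) {
  if (threads <= 1) {
    // Sequential path: no try/catch, exceptions propagate directly.
    for (int64_t t = 0; t < num_tables; ++t) {
      body(t);
    }
    return;
  }

  std::exception_ptr first_error = nullptr;

#pragma omp parallel num_threads(threads)
  {
#pragma omp for schedule(dynamic)
    for (int64_t t = 0; t < num_tables; ++t) {
      try {
        body(t);
      } catch (...) {
        // An exception must not leave the OpenMP region, so it is stored
        // here instead. Only the first stored exception is kept.
#pragma omp critical(tbe_table_error)
        {
          if (!first_error) {
            first_error = std::current_exception();
          }
        }
      }
    }
  }

  // The parallel region has ended, so all threads have joined.
  if (first_error) {
    std::rethrow_exception(first_error);
  }
}
```

Remaining iterations still run after one of them throws; the sketch only guarantees that a single exception reaches the caller after the join.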

Verification

  • Builds clean (mode/opt); confirmed OpenMP is actually enabled for this
    target by inspecting the compiled object — gen_*_codegen_cpu.cpp.pic.o
    references __kmpc_fork_call / omp_get_num_threads (the pragma is NOT a no-op,
    even though no -fopenmp appears in the TARGETS; the fbcode default toolchain
    supplies it).
  • nbit_forward CPU unit tests pass with TBE_TABLE_THREADS=4, including
    test_nbit_forward_cpu_with_table_sharing (non-monotonic weights_offsets).

Benchmark (INT4, B=512, D=128, E=100K, L=20, FP16/SUM, iters=100, 3-run avg)

| Tables | Threads | Avg (us) | BW (GB/s) | Speedup | Efficiency |
|---|---|---|---|---|---|
| 8 | 1 | 1,436 | 4.38 | --- | --- |
| 8 | 2 | 1,042 | 6.03 | 1.38x | 69% |
| 8 | 4 | 844 | 7.46 | 1.70x | 43% |
| 8 | 8 | 773 | 8.14 | 1.86x | 23% |
| 32 | 1 | 4,797 | 5.25 | --- | --- |
| 32 | 2 | 2,830 | 8.90 | 1.69x | 85% |
| 32 | 4 | 2,003 | 12.60 | 2.40x | 60% |
| 32 | 8 | 1,633 | 15.49 | 2.94x | 37% |
| 64 | 1 | 10,132 | 4.97 | --- | --- |
| 64 | 2 | 6,767 | 7.44 | 1.50x | 75% |
| 64 | 4 | 4,864 | 10.35 | 2.08x | 52% |
| 64 | 8 | 4,033 | 12.50 | 2.51x | 31% |

2 threads is the efficiency sweet spot (69-85%); efficiency falls off at higher
counts due to fixed fork/join overhead per call. This matches the production
recommendation TBE_TABLE_THREADS=2 (+7.3% QPS measured in ICE). Microbenchmark
kernel speedups overpredict end-to-end QPS (Amdahl: CPU TBE is a small fraction
of total inference latency).
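
For readers who want to recompute the Speedup and Efficiency columns, the arithmetic is sketched below, together with the standard Amdahl's-law formula behind the last sentence. The function names are hypothetical. The TBE share of end-to-end latency is not stated in this PR, so any fraction passed to amdahl_speedup is the reader's own assumption.

```cpp
#include <cassert>
#include <cmath>

// Speedup relative to the single-thread time for the same table count.
inline double speedup(double base_us, double threaded_us) {
  assert(threaded_us > 0.0);
  return base_us / threaded_us;
}

// Parallel efficiency: speedup divided by the number of threads used.
inline double efficiency(double base_us, double threaded_us, int threads) {
  assert(threads >= 1);
  return speedup(base_us, threaded_us) / threads;
}

// Amdahl's law: overall speedup when only `fraction` of total time is
// accelerated by `kernel_speedup`. `fraction` is an input the caller must
// supply; it is not a number reported in this PR.
inline double amdahl_speedup(double fraction, double kernel_speedup) {
  assert(fraction >= 0.0 && fraction <= 1.0 && kernel_speedup > 0.0);
  return 1.0 / ((1.0 - fraction) + fraction / kernel_speedup);
}
```

For the 32-table, 2-thread row, efficiency(4797, 2830, 2) reproduces the 85% in the table after rounding.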

Reviewed By: helloguo, q10

Differential Revision: D102867249

meta-codesync Bot commented Jun 8, 2026


@ShuyangLiu has exported this pull request. If you are a Meta employee, you can view the originating Diff in D102867249.

ShuyangLiu force-pushed the export-D102867249 branch from 21bca1a to b863b50 June 8, 2026 22:50
meta-codesync Bot changed the title from "Add multithreading to table lookup" to "Add multithreading to table lookup (#5849)" Jun 8, 2026
ShuyangLiu pushed a commit to ShuyangLiu/FBGEMM-1 that referenced this pull request Jun 8, 2026
ShuyangLiu pushed a commit to ShuyangLiu/FBGEMM-1 that referenced this pull request Jun 8, 2026
ShuyangLiu force-pushed the export-D102867249 branch 2 times, most recently from 93771b1 to 6cac33b June 10, 2026 19:28
ShuyangLiu pushed a commit to ShuyangLiu/FBGEMM-1 that referenced this pull request Jun 10, 2026
ShuyangLiu pushed a commit to ShuyangLiu/FBGEMM-1 that referenced this pull request Jun 10, 2026
ShuyangLiu pushed a commit to ShuyangLiu/FBGEMM-1 that referenced this pull request Jun 12, 2026
ShuyangLiu force-pushed the export-D102867249 branch 2 times, most recently from c028af0 to 6b10e0f June 12, 2026 18:03
ShuyangLiu pushed a commit to ShuyangLiu/FBGEMM-1 that referenced this pull request Jun 12, 2026
ShuyangLiu pushed a commit to ShuyangLiu/FBGEMM-1 that referenced this pull request Jun 15, 2026
Summary:

X-link: facebookresearch/FBGEMM#2767

## What

Parallelizes the per-table loop in the CPU TBE forward kernel
(IntNBitTableBatchedEmbeddingBagsCodegen::forward) across tables. The
per-table loop is embarrassingly parallel (each table reads its own weight
slice and writes a disjoint slice of the output), so fanning tables out across
threads speeds up table-heavy inference models: 1.4-1.7x at 2 threads and up to
2.9x at 8 threads in the benchmark below.

Gated by the TBE_TABLE_THREADS env var:
- TBE_TABLE_THREADS=1 (default): unchanged sequential behavior.
- TBE_TABLE_THREADS=N>1: tables are distributed over N OpenMP threads with
  dynamic scheduling (good load balancing when table sizes are skewed).

## Default behavior is unchanged

When TBE_TABLE_THREADS<=1 the helper takes an early return and runs the loop
body sequentially with no try/catch wrapper and no thread-local-state guard, so
the default path is functionally identical to the pre-change code: same
iteration order, the DEVICE-placement TORCH_CHECK in its original per-table
position, same error semantics, and the same generated machine code for the
body. The only always-on changes are mechanical and behavior-preserving
(loop-local `weights` pointer, int64 loop index).

## Design

- Raw OpenMP (#pragma omp parallel + omp for, schedule(dynamic)) rather than
  at::parallel_for, so TBE gets its own thread count independent of the global
  intra-op pool / OMP_NUM_THREADS (predictors run with OMP_NUM_THREADS=1).
- Thread count is read once from the env var and cached (thread-safe static
  init), and clamped to the number of tables.
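
A minimal sketch of the thread-count handling described above, assuming a function-local static for the one-time read. The helper names (`parse_table_threads`, `configured_table_threads`, `effective_table_threads`) and the sanity cap are hypothetical and not taken from the PR; only the `TBE_TABLE_THREADS` variable name comes from the description.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdlib>

// Hypothetical helper: turn the raw env value into a thread count.
// Unset, non-numeric, or values below 1 fall back to 1 (sequential).
inline int parse_table_threads(const char* value) {
  if (value == nullptr) {
    return 1;
  }
  char* end = nullptr;
  const long parsed = std::strtol(value, &end, 10);
  if (end == value || parsed < 1) {
    return 1;
  }
  // Upper bound is an arbitrary sanity cap chosen for this sketch.
  return static_cast<int>(std::min<long>(parsed, 1024));
}

// Read TBE_TABLE_THREADS once; C++11 guarantees thread-safe static init.
inline int configured_table_threads() {
  static const int cached =
      parse_table_threads(std::getenv("TBE_TABLE_THREADS"));
  return cached;
}

// Never use more threads than there are tables to process.
inline int effective_table_threads(int configured, int64_t num_tables) {
  const int64_t upper = std::max<int64_t>(num_tables, 1);
  const int64_t lower = std::max<int64_t>(configured, 1);
  return static_cast<int>(std::min<int64_t>(lower, upper));
}
```

Splitting parsing from the cached read keeps the parsing logic testable without touching the process environment.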

## Correctness

- Removed the function-scoped `weights_acc` pointer, which every iteration
  overwrote — a data race once the loop is parallel. Replaced with a loop-local
  pointer (identical pointer value). Every other variable in the loop body is
  already loop-local, and each table writes a disjoint output slice
  (output_acc + D_start), so results are bitwise-identical to the sequential
  path.
- The per-table DEVICE-placement TORCH_CHECK stays in its original position. In
  the threaded path it — like any other throw from the loop body (kernel
  errors, at::arange checks) — is captured (first one wins) and rethrown after
  the join, so no exception escapes the OpenMP region.
- Worker threads restore the caller's at::ThreadLocalState (dispatch keys,
  grad/inference mode, autocast, ...), so ATen calls inside the loop (e.g.
  at::arange in the nobag path) run with the correct thread-local context.

## Verification

- Builds clean (mode/opt); confirmed OpenMP is actually enabled for this
  target by inspecting the compiled object — gen_*_codegen_cpu.cpp.pic.o
  references __kmpc_fork_call / omp_get_num_threads (the pragma is NOT a no-op,
  even though no -fopenmp appears in the TARGETS; the fbcode default toolchain
  supplies it).
- nbit_forward CPU unit tests pass with TBE_TABLE_THREADS=4, including
  test_nbit_forward_cpu_with_table_sharing (non-monotonic weights_offsets).

## Benchmark (INT4, B=512, D=128, E=100K, L=20, FP16/SUM, iters=100, 3-run avg)

| Tables | Threads | Avg us | BW (GB/s) | Speedup | Efficiency |
|--------|---------|-------:|----------:|--------:|-----------:|
| 8      | 1       |  1,436 |      4.38 |     --- |        --- |
| 8      | 2       |  1,042 |      6.03 |   1.38x |        69% |
| 8      | 4       |    844 |      7.46 |   1.70x |        43% |
| 8      | 8       |    773 |      8.14 |   1.86x |        23% |
| 32     | 1       |  4,797 |      5.25 |     --- |        --- |
| 32     | 2       |  2,830 |      8.90 |   1.69x |        85% |
| 32     | 4       |  2,003 |     12.60 |   2.40x |        60% |
| 32     | 8       |  1,633 |     15.49 |   2.94x |        37% |
| 64     | 1       | 10,132 |      4.97 |     --- |        --- |
| 64     | 2       |  6,767 |      7.44 |   1.50x |        75% |
| 64     | 4       |  4,864 |     10.35 |   2.08x |        52% |
| 64     | 8       |  4,033 |     12.50 |   2.51x |        31% |

2 threads is the efficiency sweet spot (69-85%); efficiency falls off at higher
counts due to fixed fork/join overhead per call. This matches the production
recommendation TBE_TABLE_THREADS=2 (+7.3% QPS measured in ICE). Microbenchmark
kernel speedups overpredict end-to-end QPS (Amdahl: CPU TBE is a small fraction
of total inference latency).
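
For reference, the Speedup and Efficiency columns are consistent with speedup = 1-thread time / N-thread time and efficiency = speedup / N. The sketch below spells that out, along with the standard Amdahl's law formula behind the end-to-end remark; the function names are illustrative only, and the fraction of latency spent in CPU TBE is not stated in the description.

```cpp
// Speedup of an N-thread run relative to the 1-thread baseline.
inline double speedup(double baseline_us, double threaded_us) {
  return baseline_us / threaded_us;
}

// Parallel efficiency: share of the ideal N-fold speedup actually achieved.
inline double efficiency(double baseline_us, double threaded_us, int threads) {
  return speedup(baseline_us, threaded_us) / threads;
}

// Amdahl's law: overall speedup when only `accelerated_fraction` of the
// runtime gets faster by a factor of `kernel_speedup`.
inline double amdahl_speedup(double accelerated_fraction, double kernel_speedup) {
  return 1.0 /
      ((1.0 - accelerated_fraction) + accelerated_fraction / kernel_speedup);
}
```

For the 8-table row at 2 threads, `speedup(1436, 1042)` is about 1.38 and `efficiency(1436, 1042, 2)` is about 0.69, matching the table.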

Reviewed By: helloguo, q10

Differential Revision: D102867249