Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/superbench-config.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,8 @@ A list of models to run, only supported in model-benchmark.
squeezenet1_0 | squeezenet1_1 |
vgg11 | vgg11_bn | vgg13 | vgg13_bn | vgg16 | vgg16_bn | vgg19_bn | vgg19 |
bert-base | bert-large | gpt2-small | gpt2-medium | gpt2-large | gpt2-xl |
llama2-7b | llama2-13b | llama2-70b ]
llama2-7b | llama2-13b | llama2-70b |
mixtral-8x7b | mixtral-8x22b ]
```
* default value: `[ ]`

Expand Down
1 change: 1 addition & 0 deletions docs/user-tutorial/benchmarks/model-benchmarks.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Run training or inference tasks with single or half precision for deep learning
including the following categories:
* GPT: gpt2-small, gpt2-medium, gpt2-large and gpt2-xl
* LLAMA: llama2-7b, llama2-13b, llama2-70b
* MoE: mixtral-8x7b, mixtral-8x22b
* BERT: bert-base and bert-large
* LSTM
* CNN, listed in [`torchvision.models`](https://pytorch.org/vision/0.8/models.html), including:
Expand Down
38 changes: 38 additions & 0 deletions superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel

if MixtralBenchmarkModel is not None:
from transformers import MixtralConfig


class torch2onnxExporter():
Expand Down Expand Up @@ -122,6 +126,40 @@ def __init__(self):
self.num_classes,
),
}

# Only include Mixtral models if MixtralBenchmarkModel is available
if MixtralBenchmarkModel is not None:
self.benchmark_models.update(
{
'mixtral-8x7b':
lambda: MixtralBenchmarkModel(
MixtralConfig(
hidden_size=4096,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
intermediate_size=14336,
max_position_embeddings=32768,
router_aux_loss_coef=0.02,
),
self.num_classes,
),
'mixtral-8x22b':
lambda: MixtralBenchmarkModel(
MixtralConfig(
hidden_size=6144,
num_hidden_layers=56,
num_attention_heads=48,
num_key_value_heads=8,
intermediate_size=16384,
max_position_embeddings=65536,
router_aux_loss_coef=0.001,
),
self.num_classes,
),
}
)

self._onnx_model_path = Path(torch.hub.get_dir()) / 'onnx'
self._onnx_model_path.mkdir(parents=True, exist_ok=True)

Expand Down
10 changes: 9 additions & 1 deletion superbench/benchmarks/model_benchmarks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,13 @@
from superbench.benchmarks.model_benchmarks.pytorch_cnn import PytorchCNN
from superbench.benchmarks.model_benchmarks.pytorch_lstm import PytorchLSTM
from superbench.benchmarks.model_benchmarks.megatron_gpt3 import MegatronGPT
from superbench.benchmarks.model_benchmarks.pytorch_llama import PytorchLlama
from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral

# Names that are always exported from this package.
__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama']

# PytorchMixtral may be bound to None by the pytorch_mixtral module, so it is exported
# only when it refers to a real class, and it is listed exactly once.
if PytorchMixtral is not None:
    __all__.append('PytorchMixtral')
33 changes: 33 additions & 0 deletions superbench/benchmarks/model_benchmarks/pytorch_mixtral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Module of the Pytorch Mixtral model gate."""

import sys
from superbench.benchmarks import BenchmarkRegistry

if sys.version_info < (3, 8):
    # The Mixtral implementation is not imported on interpreters older than 3.8
    # (presumably because of the transformers release it depends on - TODO confirm).
    # Both names are still bound so that importers can test them against None.
    MixtralBenchmarkModel = None
    PytorchMixtral = None
else:
    from superbench.benchmarks.model_benchmarks.pytorch_mixtral_impl import MixtralBenchmarkModel, PytorchMixtral

    # Registration lives in this branch so that None is never handed to the registry.

    # Register Mixtral benchmark with 8x7b parameters.
    # Ref: https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json
    BenchmarkRegistry.register_benchmark(
        'pytorch-mixtral-8x7b',
        PytorchMixtral,
        parameters=(
            '--hidden_size=4096 --num_hidden_layers=32 --num_attention_heads=32 --intermediate_size=14336 '
            '--num_key_value_heads=8 --max_position_embeddings=32768 --router_aux_loss_coef=0.02'
        ),
    )

    # Register Mixtral benchmark with 8x22b parameters.
    # Ref: https://huggingface.co/mistralai/Mixtral-8x22B-v0.1/blob/main/config.json
    BenchmarkRegistry.register_benchmark(
        'pytorch-mixtral-8x22b',
        PytorchMixtral,
        parameters=(
            '--hidden_size=6144 --num_hidden_layers=56 --num_attention_heads=48 --intermediate_size=16384 '
            '--num_key_value_heads=8 --max_position_embeddings=65536 --router_aux_loss_coef=0.001'
        ),
    )

__all__ = ['MixtralBenchmarkModel', 'PytorchMixtral']
254 changes: 254 additions & 0 deletions superbench/benchmarks/model_benchmarks/pytorch_mixtral_impl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Module of the Pytorch Mixtral model implementation."""

import torch
from transformers import MixtralModel, MixtralConfig
try:
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
except ImportError:
te = None

from superbench.common.utils import logger
from superbench.benchmarks import Precision
from superbench.benchmarks.model_benchmarks.model_base import Optimizer
from superbench.benchmarks.model_benchmarks.pytorch_base import PytorchBase
from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDataset


class MixtralBenchmarkModel(torch.nn.Module):
    """Mixtral backbone followed by a linear classification head, used for benchmarking."""
    def __init__(self, config, num_classes):
        """Build the Mixtral backbone and the classification head.

        Args:
            config (MixtralConfig): Configurations of Mixtral model.
            num_classes (int): The number of objects for classification.
        """
        super().__init__()
        self._Mixtral = MixtralModel(config)
        # The head maps every hidden vector of size config.hidden_size to num_classes scores.
        self._linear = torch.nn.Linear(in_features=config.hidden_size, out_features=num_classes)

    def forward(self, input):
        """Run the backbone on the token ids and apply the classification head.

        Args:
            input (torch.LongTensor): Indices of input sequence tokens in the vocabulary,
              shape (batch_size, sequence_length).

        Return:
            result (torch.FloatTensor): The first element of the backbone output passed through
              the linear head, so the last dimension has size num_classes. Presumably shaped
              (batch_size, sequence_length, num_classes) - verify against the transformers version in use.
        """
        hidden_states = self._Mixtral(input)[0]
        return self._linear(hidden_states)


class PytorchMixtral(PytorchBase):
    """The Mixtral benchmark class.

    Builds a MixtralBenchmarkModel from the parsed benchmark arguments and records the
    per-step time of training and inference on a randomly generated dataset.
    """
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)
        self._config = None
        # Set by _create_model() for FP8 precisions only; None means "run without fp8_autocast".
        self._fp8_recipe = None
        self._supported_precision = [
            Precision.FLOAT32,
            Precision.FLOAT16,
            Precision.FP8_HYBRID,
            Precision.FP8_E4M3,
        ]
        self._optimizer_type = Optimizer.ADAMW
        self._loss_fn = torch.nn.CrossEntropyLoss()

    def add_parser_arguments(self):
        """Add the Mixtral-specified arguments.

        Mixtral model reference: https://huggingface.co/docs/transformers/model_doc/Mixtral
        """
        super().add_parser_arguments()

        self._parser.add_argument('--num_classes', type=int, default=100, required=False, help='Num of class.')
        self._parser.add_argument('--hidden_size', type=int, default=4096, required=False, help='Hidden size.')
        self._parser.add_argument(
            '--num_hidden_layers', type=int, default=32, required=False, help='The number of hidden layers.'
        )
        self._parser.add_argument(
            '--num_attention_heads', type=int, default=32, required=False, help='The number of attention heads.'
        )
        self._parser.add_argument(
            '--intermediate_size',
            type=int,
            default=14336,
            required=False,
            help='Dimension of the MLP representations.'
        )
        self._parser.add_argument('--seq_len', type=int, default=512, required=False, help='Sequence length.')
        self._parser.add_argument(
            '--num_key_value_heads',
            type=int,
            default=8,
            required=False,
            help='The number of key_value heads that should be used to implement Grouped Query Attention.'
        )
        self._parser.add_argument(
            '--max_position_embeddings',
            type=int,
            default=None,
            required=False,
            help='Maximum sequence length that Mixtral supports'
        )
        self._parser.add_argument(
            '--router_aux_loss_coef',
            type=float,
            default=0.001,
            required=False,
            help='The aux loss factor for the total loss.'
        )

    def _generate_dataset(self):
        """Generate dataset for benchmarking according to shape info.

        Return:
            True if dataset is created successfully.
        """
        # Random token ids, one row of seq_len entries per sample.
        self._dataset = TorchRandomDataset(
            [self._args.sample_count, self._args.seq_len], self._world_size, dtype=torch.long
        )
        if len(self._dataset) == 0:
            logger.error(f'Generate random dataset failed - model: {self._name}')
            return False

        return True

    def _create_model(self, precision):
        """Construct the model for benchmarking.

        Args:
            precision (Precision): precision of model and input data, such as float32, float16.

        Return:
            True if the model and the random training target are created successfully, False otherwise.
        """
        fp8_prefix = 'FP8_'

        # Drop any recipe left over from an earlier FP8 call so that a non-FP8 model
        # created afterwards is not executed under fp8_autocast.
        self._fp8_recipe = None

        config_kwargs = {
            'hidden_size': self._args.hidden_size,
            'num_hidden_layers': self._args.num_hidden_layers,
            'num_attention_heads': self._args.num_attention_heads,
            'num_key_value_heads': self._args.num_key_value_heads,
            'intermediate_size': self._args.intermediate_size,
            'router_aux_loss_coef': self._args.router_aux_loss_coef,
        }
        # --max_position_embeddings defaults to None; forward it only when it was given so that
        # MixtralConfig keeps its own default instead of having it overridden with None.
        if self._args.max_position_embeddings is not None:
            config_kwargs['max_position_embeddings'] = self._args.max_position_embeddings
        self._config = MixtralConfig(**config_kwargs)

        enable_fp8 = precision.name.startswith(fp8_prefix)
        if enable_fp8 and te is None:
            logger.error(
                f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
                ' message: Cannot find transformer_engine.'
            )
            return False
        if enable_fp8 and not self._gpu_available:
            logger.error(
                f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
                ' message: FP8 is only supported on GPU.'
            )
            return False

        try:
            self._model = MixtralBenchmarkModel(self._config, self._args.num_classes)
            if enable_fp8:
                # Slice the prefix off rather than using str.strip(), which treats its argument
                # as a set of characters and could also eat matching characters of the format name.
                fp8_format_name = precision.name[len(fp8_prefix):]
                self._fp8_recipe = DelayedScaling(
                    fp8_format=Format[fp8_format_name],
                    amax_history_len=16,
                    amax_compute_algo='max',
                )
                self._to_te_model(self._model.to(dtype=torch.float16))
            else:
                self._model = self._model.to(dtype=getattr(torch, precision.value))
            if self._gpu_available:
                self._model = self._model.cuda()
        except Exception as e:
            # Any failure while building or converting the model is reported as a failed creation.
            logger.error(
                f'Create model with specified precision failed - model: {self._name}, '
                f'precision: {precision}, message: {e}.'
            )
            return False

        # One random class index per sample of a batch, reused as the label for every training step.
        self._target = torch.LongTensor(self._args.batch_size).random_(self._args.num_classes)
        if self._gpu_available:
            self._target = self._target.cuda()

        return True

    def _train_step(self, precision):
        """Define the training process.

        Args:
            precision (Precision): precision of model and input data, such as float32, float16.

        Return:
            The step-time list of every training step.
        """
        duration = []
        curr_step = 0
        check_frequency = 100
        # Keep cycling over the dataloader until _is_finished() reports completion.
        while True:
            for sample in self._dataloader:
                start = self._timer()
                if self._gpu_available:
                    sample = sample.cuda()
                self._optimizer.zero_grad()
                if self._fp8_recipe is not None:
                    with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
                        output = self._model(sample)
                else:
                    output = self._model(sample)
                # Use the scores at the last sequence position of every sample as the prediction.
                loss = self._loss_fn(output[range(self._args.batch_size), -1], self._target)
                loss.backward()
                self._optimizer.step()
                end = self._timer()
                curr_step += 1
                if curr_step > self._args.num_warmup:
                    # Save the step time of every training/inference step, unit is millisecond.
                    duration.append((end - start) * 1000)
                    self._log_step_time(curr_step, precision, duration)
                if self._is_finished(curr_step, end, check_frequency):
                    return duration

    def _inference_step(self, precision):
        """Define the inference process.

        Args:
            precision (Precision): precision of model and input data,
              such as float32, float16.

        Return:
            The latency list of every inference operation.
        """
        duration = []
        curr_step = 0
        with torch.no_grad():
            self._model.eval()
            # Keep cycling over the dataloader until _is_finished() reports completion.
            while True:
                for sample in self._dataloader:
                    start = self._timer()
                    if self._gpu_available:
                        sample = sample.cuda()
                    if self._fp8_recipe is not None:
                        with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
                            self._model(sample)
                    else:
                        self._model(sample)
                    end = self._timer()
                    curr_step += 1
                    if curr_step > self._args.num_warmup:
                        # Save the step time of every training/inference step, unit is millisecond.
                        duration.append((end - start) * 1000)
                        self._log_step_time(curr_step, precision, duration)
                    if self._is_finished(curr_step, end):
                        return duration
Loading
Loading