-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprofile_CUDA_run.py
More file actions
70 lines (50 loc) · 2.12 KB
/
profile_CUDA_run.py
File metadata and controls
70 lines (50 loc) · 2.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#trying to profile a simple GPT model using torch profiler
#here we will try to use the classes, create a model and run a forward pass
import torch
from models.gpt import GPT
from tokenizer.bpe import BPE
from torch.profiler import profile, ProfilerActivity, record_function
if __name__ == "__main__":
# Define model parameters
vocab_size = 512
embed_size = 512
num_blocks = 2
heads = 4
# Initialize the GPT model
model = GPT(vocab_size, embed_size, num_blocks, heads)
# Create a simple tokenizer
tokenizer = BPE()
tokenizer.load_trained('bpe_tokenizer.txt')
# Sample input text
input_text = "hello, i am your friendly neighborhood"
# Tokenize the input text
input_ids = tokenizer.encode(input_text)
input_tensor = torch.tensor([input_ids]) # Shape: (1, seq_length)
# Run a forward pass through the model
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
device = "cuda"
activities += [ProfilerActivity.CUDA]
elif torch.xpu.is_available():
device = "xpu"
activities += [ProfilerActivity.XPU]
else:
print(
"Neither CUDA nor XPU devices are available to demonstrate profiling on acceleration devices"
)
import sys
sys.exit(0)
sort_by_keyword = device + "_time_total"
with profile(activities=activities, record_shapes=True) as prof:
with record_function("model_inference"):
output_logits =model(input_tensor.to(device))
print(prof.key_averages().table(sort_by=sort_by_keyword, row_limit=10))
print("Input IDs:", input_ids)
print("Output logits shape:", output_logits.shape) # Should be (1, seq_length, vocab_size)
softmax_output = torch.softmax(output_logits, dim=-1)
predicted_tokens = torch.argmax(softmax_output, dim=-1)
print("Predicted token IDs:", predicted_tokens)
# Decode the predicted token IDs back to text
predicted_text = tokenizer.decode(predicted_tokens[0].tolist())
print("Predicted text:", predicted_text)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))