-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
126 lines (110 loc) · 3.98 KB
/
Copy pathmodel.py
File metadata and controls
126 lines (110 loc) · 3.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import time
import openai
from openai import OpenAI
from transformers import AutoTokenizer
import torch
import transformers
from prompts import identity
from pprint import pprint
def get_model(args):
    """Build the model wrapper selected by ``args.model``.

    Args:
        args: Object exposing ``model`` (the model name) and ``temperature``;
            ``api_key`` is additionally read when a GPT model is selected.

    Returns:
        A ``GPT`` instance when the model name contains ``'gpt'``, otherwise
        a ``LLaMA`` instance when it contains ``'llama'``.

    Raises:
        ValueError: If the model name matches neither family. Previously the
            function silently returned ``None`` in that case.
    """
    model_name, temperature = args.model, args.temperature
    if 'gpt' in model_name:
        # Direct OpenAI API. (An Azure deployment would instead be built
        # without the api_key argument: GPT(model_name, temperature).)
        return GPT(args.api_key, model_name, temperature)
    if 'llama' in model_name:
        return LLaMA(model_name, temperature)
    raise ValueError(
        f"Unsupported model {model_name!r}: expected a name containing "
        "'gpt' or 'llama'"
    )
class Model:
    """Base class for model wrappers; stores a post-processing callable."""

    def __init__(self):
        # Default post-processor is `identity` imported from `prompts`
        # (presumably a pass-through -- confirm against prompts.py).
        self.post_process_fn = identity

    def set_post_process_fn(self, post_process_fn):
        """Replace the stored post-processing callable.

        Args:
            post_process_fn: Callable stored as ``self.post_process_fn``.
        """
        self.post_process_fn = post_process_fn
class GPT(Model):
    """Chat-completion wrapper around the OpenAI client."""

    def __init__(self, api_key, model_name, temperature):
        """Create the OpenAI client and remember the generation settings.

        Args:
            api_key: OpenAI API key; a falsy value is passed on as ``None``.
            model_name: Model identifier sent with every request.
            temperature: Sampling temperature sent with every request.
        """
        super().__init__()
        # Passing None lets the SDK fall back to the OPENAI_API_KEY environment
        # variable automatically, so the key never needs to be in any script.
        self.client = OpenAI(api_key=api_key or None)
        self.model_name = model_name
        self.temperature = temperature
        # Number of requests rejected with BadRequestError so far.
        self.badrequest_count = 0

    def get_response(self, **kwargs):
        """Call the chat-completions endpoint, retrying transient failures.

        Connection errors and timeouts are retried after 30 seconds, rate
        limits after 10 seconds. Retries are unbounded, but performed in a
        loop rather than by recursion so a long outage cannot exhaust the
        call stack.

        Args:
            **kwargs: Forwarded unchanged to ``chat.completions.create``.

        Returns:
            The completion response, or ``None`` if the API rejected the
            request with ``BadRequestError``.
        """
        while True:
            try:
                return self.client.chat.completions.create(**kwargs)
            except openai.APITimeoutError:
                # Must precede APIConnectionError: the timeout error is a
                # subclass of it and would otherwise never reach this branch.
                print('APITimeoutError')
                delay = 30
            except openai.APIConnectionError:
                print('APIConnectionError')
                delay = 30
            except openai.RateLimitError:
                print('RateLimitError')
                delay = 10
            except openai.BadRequestError:
                # Retrying an invalid request cannot succeed; count and give up.
                self.badrequest_count += 1
                print(f'BadRequestError (count: {self.badrequest_count})')
                return None
            time.sleep(delay)

    def forward(self, head, prompts):
        """Run a multi-turn conversation, one request per prompt.

        Each prompt is appended as a user message, the model's reply is
        appended as an assistant message, and the growing history is sent
        with the next prompt.

        Args:
            head: Content of the system message.
            prompts: Sequence of user messages, sent in order.

        Returns:
            A ``(result, info)`` tuple. ``result`` is the post-processed final
            reply, and ``info`` holds the last response's usage fields
            (completion_tokens, prompt_tokens, total_tokens) plus
            ``'response'`` (raw final reply) and ``'message'`` (full message
            history). If a request fails, or ``prompts`` is empty, ``result``
            is ``None`` and both ``info['response']`` and ``info['message']``
            are ``None``.
        """
        messages = [
            {"role": "system", "content": head}
        ]
        info = {}
        if not prompts:
            # Nothing to send: report it the same way as a failed request
            # instead of raising KeyError on the missing 'response' entry.
            info['response'] = None
            info['message'] = None
            return None, info

        for prompt in prompts:
            messages.append(
                {"role": "user", "content": prompt}
            )
            response = self.get_response(
                model=self.model_name,
                messages=messages,
                temperature=self.temperature,
            )
            if response is None:
                info['response'] = None
                info['message'] = None
                return None, info

            reply = response.choices[0].message.content
            messages.append(
                {"role": "assistant", "content": reply}
            )
            # completion_tokens, prompt_tokens, total_tokens
            usage = response.usage
            info = dict(usage) if usage is not None else {}
            info['response'] = reply
            info['message'] = messages
        return self.post_process_fn(info['response']), info
class LLaMA(Model):
    """Text-generation wrapper built on a ``transformers`` pipeline."""

    def __init__(self, model_name, temperature):
        """Load the tokenizer and build the text-generation pipeline.

        Args:
            model_name: Name or path handed to both the tokenizer loader and
                the pipeline.
            temperature: Value passed to the pipeline constructor.
        """
        super().__init__()
        self.model_name, self.temperature = model_name, temperature

        tok = AutoTokenizer.from_pretrained(model_name)
        # Explicit pad token with left-side padding.
        tok.pad_token = "[PAD]"
        tok.padding_side = "left"
        self.tokenizer = tok

        self.pipeline = transformers.pipeline(
            "text-generation",
            model=model_name,
            tokenizer=tok,
            torch_dtype=torch.float16,
            device_map="auto",
            temperature=temperature,
        )

    def forward(self, head, prompts):
        """Generate text for the first prompt only.

        Args:
            head: Not used by this implementation.
            prompts: Sequence of prompts; only ``prompts[0]`` is sent.

        Returns:
            A ``(result, info)`` tuple: the post-processed generated text and
            a dict with ``'message'`` (the prompt) and ``'response'`` (the raw
            generated text).
        """
        first_prompt = prompts[0]
        # NOTE(review): do_sample=False together with a pipeline-level
        # temperature -- presumably the temperature has no effect; confirm.
        generation_kwargs = dict(
            do_sample=False,
            top_k=1,
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        outputs = self.pipeline(first_prompt, **generation_kwargs)
        generated = outputs[0]['generated_text']  # str
        info = {'message': first_prompt, 'response': generated}
        return self.post_process_fn(generated), info