1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342
| class Llama: @staticmethod def build( ckpt_dir: str, tokenizer_path: str, max_seq_len: int, max_batch_size: int, model_parallel_size: Optional[int] = None, seed: int = 1, ) -> "Llama": """ 构建一个Llama实例,通过初始化和加载预训练模型。
Args: ckpt_dir (str): 包含检查点文件的目录路径。 tokenizer_path (str): tokenizer文件的路径。 max_seq_len (int): 输入文本的最大序列长度。 max_batch_size (int): 推理的最大批量大小。 model_parallel_size (Optional[int], optional): 模型并行进程的数量。 如果未提供,将从环境中确定。默认为None。
Returns: Llama: 带有加载的模型和分词器的Llama类的实例。
Raises: AssertionError: 如果指定目录中没有检查点文件,或者模型并行大小与检查点文件的数量不匹配。
Note: 此方法会初始化分布式进程组,将设备设置为CUDA,并加载预训练模型和分词器。
""" if not torch.distributed.is_initialized(): torch.distributed.init_process_group("nccl") if not model_parallel_is_initialized(): if model_parallel_size is None: model_parallel_size = int(os.environ.get("WORLD_SIZE", 1)) initialize_model_parallel(model_parallel_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0)) torch.cuda.set_device(local_rank)
torch.manual_seed(seed)
if local_rank > 0: sys.stdout = open(os.devnull, "w")
start_time = time.time() checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" assert model_parallel_size == len( checkpoints ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}" ckpt_path = checkpoints[get_model_parallel_rank()] checkpoint = torch.load(ckpt_path, map_location="cpu") with open(Path(ckpt_dir) / "params.json", "r") as f: params = json.loads(f.read())
model_args: ModelArgs = ModelArgs( max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params, ) tokenizer = Tokenizer(model_path=tokenizer_path) model_args.vocab_size = tokenizer.n_words torch.set_default_tensor_type(torch.cuda.HalfTensor) model = Transformer(model_args) model.load_state_dict(checkpoint, strict=False) print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer)
def __init__(self, model: Transformer, tokenizer: Tokenizer): self.model = model self.tokenizer = tokenizer
@torch.inference_mode() def generate( self, prompt_tokens: List[List[int]], max_gen_len: int, temperature: float = 0.6, top_p: float = 0.9, logprobs: bool = False, echo: bool = False, ) -> Tuple[List[List[int]], Optional[List[List[float]]]]: """ 基于提供的提示使用语言生成模型生成文本序列。
Args: prompt_tokens (List[List[int]]): Tokenized prompts的列表, 其中每个提示表示为整数列表。 max_gen_len (int): 生成文本序列的最大长度。 temperature (float, optional): 控制采样中随机性的温度值。默认为0.6。 top_p (float, optional): 核心采样的top-p概率阈值。默认为0.9。 logprobs (bool, optional): 指示是否计算Token对数概率的标志。默认为False。 echo (bool, optional): 指示是否在生成的输出中包括提示Token的标志。默认为False。
Returns: Tuple[List[List[int]], Optional[List[List[float]]]]: 包含生成的Token序列的元组,如果logprobs为True,则包含相应的Token对数概率。
Note: 此方法使用提供的提示作为生成文本的基础。它使用核心采样(nucleus sampling)来产生具有控制随机性的文本。 如果logprobs为True,则为每个生成的Token计算Token对数概率。 """ params = self.model.params bsz = len(prompt_tokens) assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens) max_prompt_len = max(len(t) for t in prompt_tokens) assert max_prompt_len <= params.max_seq_len total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
pad_id = self.tokenizer.pad_id tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda") for k, t in enumerate(prompt_tokens): tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") if logprobs: token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
prev_pos = 0 eos_reached = torch.tensor([False] * bsz, device="cuda") input_text_mask = tokens != pad_id if min_prompt_len == total_len: logits = self.model.forward(tokens, prev_pos) token_logprobs = -F.cross_entropy( input=logits.transpose(1, 2), target=tokens, reduction="none", ignore_index=pad_id, )
for cur_pos in range(min_prompt_len, total_len): logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) if temperature > 0: probs = torch.softmax(logits[:, -1] / temperature, dim=-1) next_token = sample_top_p(probs, top_p) else: next_token = torch.argmax(logits[:, -1], dim=-1)
next_token = next_token.reshape(-1) next_token = torch.where( input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token ) tokens[:, cur_pos] = next_token if logprobs: token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( input=logits.transpose(1, 2), target=tokens[:, prev_pos + 1 : cur_pos + 1], reduction="none", ignore_index=pad_id, ) eos_reached |= (~input_text_mask[:, cur_pos]) & ( next_token == self.tokenizer.eos_id ) prev_pos = cur_pos if all(eos_reached): break
if logprobs: token_logprobs = token_logprobs.tolist() out_tokens, out_logprobs = [], [] for i, toks in enumerate(tokens.tolist()): start = 0 if echo else len(prompt_tokens[i]) toks = toks[start : len(prompt_tokens[i]) + max_gen_len] probs = None if logprobs: probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len] if self.tokenizer.eos_id in toks: eos_idx = toks.index(self.tokenizer.eos_id) toks = toks[:eos_idx] probs = probs[:eos_idx] if logprobs else None out_tokens.append(toks) out_logprobs.append(probs) return (out_tokens, out_logprobs if logprobs else None)
def text_completion( self, prompts: List[str], temperature: float = 0.6, top_p: float = 0.9, max_gen_len: Optional[int] = None, logprobs: bool = False, echo: bool = False, ) -> List[CompletionPrediction]: """ 对一组提示词使用语言生成模型进行文本补完。
Args: prompts (List[str]): 需要补完的文本提示词列表。 temperature (float, optional): 控制采样中随机性的温度值。默认为0.6。 top_p (float, optional): 核心采样的top-p概率阈值。默认为0.9。 max_gen_len (Optional[int], optional): 生成完成序列的最大长度。 如果未提供,将设置为模型的最大序列长度减1。 logprobs (bool, optional): 指示是否计算Token对数概率的标志。默认为False。 echo (bool, optional): 指示是否在生成的输出中包括提示Token的标志。默认为False。
Returns: List[CompletionPrediction]: 完成预测的列表,每个预测包含生成的文本完成。
Note: 此方法为提供的提示词生成文本补完,并使用核心采样引入控制随机性。 如果logprobs被设置为True,则为每个生成的Token计算对数概率。
""" if max_gen_len is None: max_gen_len = self.model.params.max_seq_len - 1 prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts] generation_tokens, generation_logprobs = self.generate( prompt_tokens=prompt_tokens, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, logprobs=logprobs, echo=echo, ) if logprobs: return [ { "generation": self.tokenizer.decode(t), "tokens": [self.tokenizer.decode(x) for x in t], "logprobs": logprobs_i, } for t, logprobs_i in zip(generation_tokens, generation_logprobs) ] return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]
def chat_completion( self, dialogs: List[Dialog], temperature: float = 0.6, top_p: float = 0.9, max_gen_len: Optional[int] = None, logprobs: bool = False, ) -> List[ChatPrediction]: """ 使用语言生成模型,对一个交谈对话的列表生成assistant回复。
Args: dialogs (List[Dialog]): 会话对话的列表,其中每个对话都是消息列表。 temperature (float, optional): 控制采样中随机性的温度值。默认为0.6。 top_p (float, optional): 核心采样的top-p概率阈值。默认为0.9。 max_gen_len (Optional[int], optional): 生成响应序列的最大长度。如果未提供,将设置为模型的最大序列长度减1。 logprobs (bool, optional): 指示是否计算Token对数概率的标志。默认为False。 Returns: List[ChatPrediction]: 聊天预测列表,每个预测包含assistant生成的响应。
Raises: AssertionError: 如果对话中的最后一条消息不是来自用户。 AssertionError: 如果对话角色不按照所需的'user'、'assistant'和可选的'system'顺序。
Note: 此方法为提供的会话对话生成assistant的响应。 它使用核心采样引入文本生成中的控制随机性。 如果logprobs设置为True,则将为每个生成的Token计算对数概率。
""" if max_gen_len is None: max_gen_len = self.model.params.max_seq_len - 1 prompt_tokens = [] unsafe_requests = [] for dialog in dialogs: unsafe_requests.append( any([tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog]) ) if dialog[0]["role"] == "system": dialog = [ { "role": dialog[1]["role"], "content": B_SYS + dialog[0]["content"] + E_SYS + dialog[1]["content"], } ] + dialog[2:] assert all([msg["role"] == "user" for msg in dialog[::2]]) and all( [msg["role"] == "assistant" for msg in dialog[1::2]] ), ( "model only supports 'system', 'user' and 'assistant' roles, " "starting with 'system', then 'user' and alternating (u/a/u/a/u...)" ) dialog_tokens: List[int] = sum( [ self.tokenizer.encode( f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ", bos=True, eos=True, ) for prompt, answer in zip( dialog[::2], dialog[1::2], ) ], [], ) assert ( dialog[-1]["role"] == "user" ), f"Last message must be from user, got {dialog[-1]['role']}" dialog_tokens += self.tokenizer.encode( f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}", bos=True, eos=False, ) prompt_tokens.append(dialog_tokens)
generation_tokens, generation_logprobs = self.generate( prompt_tokens=prompt_tokens, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, logprobs=logprobs, ) if logprobs: return [ { "generation": { "role": "assistant", "content": self.tokenizer.decode(t) if not unsafe else UNSAFE_ERROR, }, "tokens": [self.tokenizer.decode(x) for x in t], "logprobs": logprobs_i, } for t, logprobs_i, unsafe in zip( generation_tokens, generation_logprobs, unsafe_requests ) ] return [ { "generation": { "role": "assistant", "content": self.tokenizer.decode(t) if not unsafe else UNSAFE_ERROR, } } for t, unsafe in zip(generation_tokens, unsafe_requests) ]
|