From d3aa29f914761e8ea0298051fbaf8dd173e94db5 Mon Sep 17 00:00:00 2001
From: Philipp Schmid
Date: Fri, 19 Apr 2024 06:56:50 +0000
Subject: [PATCH] Example for AutoModelForCausalLM (#11)

- Example for AutoModelForCausalLM (7e03f75b800b1978e2f57e01ee8e99d30c41e47e)

Co-authored-by: Pedro Cuenca
---
 README.md | 48 +++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 47 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 5bc9c21..3d58fbf 100644
--- a/README.md
+++ b/README.md
@@ -273,7 +273,9 @@ This repository contains two versions of Meta-Llama-3-8B-Instruct, for use with
 
 ### Use with transformers
 
-See the snippet below for usage with Transformers:
+You can run conversational inference using the Transformers pipeline abstraction, or by leveraging the Auto classes with the `generate()` function. Let's see examples of both.
+
+#### Transformers pipeline
 
 ```python
 import transformers
@@ -315,6 +317,50 @@ outputs = pipeline(
 print(outputs[0]["generated_text"][len(prompt):])
 ```
 
+#### Transformers AutoModelForCausalLM
+
+```python
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+
+model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+)
+
+messages = [
+    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
+    {"role": "user", "content": "Who are you?"},
+]
+
+input_ids = tokenizer.apply_chat_template(
+    messages,
+    add_generation_prompt=True,
+    return_tensors="pt"
+).to(model.device)
+
+terminators = [
+    tokenizer.eos_token_id,
+    tokenizer.convert_tokens_to_ids("<|eot_id|>")
+]
+
+outputs = model.generate(
+    input_ids,
+    max_new_tokens=256,
+    eos_token_id=terminators,
+    do_sample=True,
+    temperature=0.6,
+    top_p=0.9,
+)
+response = outputs[0][input_ids.shape[-1]:]
+print(tokenizer.decode(response, skip_special_tokens=True))
+```
+
+
 ### Use with `llama3`
 
 Please, follow the instructions in the [repository](https://github.com/meta-llama/llama3)