Example for AutoModelForCausalLM (#11)

- Example for AutoModelForCausalLM (7e03f75b800b1978e2f57e01ee8e99d30c41e47e)

Co-authored-by: Pedro Cuenca <pcuenq@users.noreply.huggingface.co>

parent 2b72492696
commit d3aa29f914

README.md (48 changed lines)
@@ -273,7 +273,9 @@ This repository contains two versions of Meta-Llama-3-8B-Instruct, for use with
 
 ### Use with transformers
 
-See the snippet below for usage with Transformers:
+You can run conversational inference using the Transformers pipeline abstraction, or by leveraging the Auto classes with the `generate()` function. Let's see examples of both.
+
+#### Transformers pipeline
 
 ```python
 import transformers
@@ -315,6 +317,50 @@ outputs = pipeline(
 print(outputs[0]["generated_text"][len(prompt):])
 ```
 
+#### Transformers AutoModelForCausalLM
+
+```python
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+
+model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+)
+
+messages = [
+    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
+    {"role": "user", "content": "Who are you?"},
+]
+
+input_ids = tokenizer.apply_chat_template(
+    messages,
+    add_generation_prompt=True,
+    return_tensors="pt"
+).to(model.device)
+
+terminators = [
+    tokenizer.eos_token_id,
+    tokenizer.convert_tokens_to_ids("<|eot_id|>")
+]
+
+outputs = model.generate(
+    input_ids,
+    max_new_tokens=256,
+    eos_token_id=terminators,
+    do_sample=True,
+    temperature=0.6,
+    top_p=0.9,
+)
+response = outputs[0][input_ids.shape[-1]:]
+print(tokenizer.decode(response, skip_special_tokens=True))
+```
+
+
 ### Use with `llama3`
 
 Please follow the instructions in the [repository](https://github.com/meta-llama/llama3).
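
A note for readers of this diff: the first hunk shows only the opening of the pipeline example (`import transformers` through the `outputs = pipeline(` context line), so the pipeline body itself is elided here. Below is a minimal sketch of what a complete call of that shape looks like. It assumes the same `model_id`, chat messages, terminators, and sampling parameters as the `generate()` example in the second hunk; it is not the verbatim README snippet.

```python
# Sketch only: assumes the same model ID, messages, and generation
# parameters as the AutoModelForCausalLM example in this commit.
import transformers
import torch

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# Build a text-generation pipeline that loads the model in bfloat16
# and places it automatically on the available devices.
pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)

messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

# Render the chat messages into a single prompt string using the
# model's chat template.
prompt = pipeline.tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

# Llama 3 Instruct ends each assistant turn with <|eot_id|>, so stop
# on it in addition to the tokenizer's default eos token.
terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

outputs = pipeline(
    prompt,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)
# The pipeline returns the prompt plus the completion; slice the prompt
# off, matching the print(...) context line in the second hunk.
print(outputs[0]["generated_text"][len(prompt):])
```

Both examples pass two stop tokens because Llama 3 Instruct marks the end of an assistant turn with `<|eot_id|>` rather than the tokenizer's default `eos_token_id`; stopping on either keeps generation from running past the end of the reply.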