diff --git a/README.md b/README.md index 843cd9d..7640174 100644 --- a/README.md +++ b/README.md @@ -104,11 +104,15 @@ from transformers import AutoModelForCausalLM, AutoTokenizer model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5", trust_remote_code=True, torch_dtype="auto") tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True, torch_dtype="auto") + +device = "cuda" +model.to(device) + inputs = tokenizer('''```python def print_prime(n): """ Print all primes between 1 and n - """''', return_tensors="pt", return_attention_mask=False) + """''', return_tensors="pt", return_attention_mask=False).to(device) outputs = model.generate(**inputs, max_length=200) text = tokenizer.batch_decode(outputs)[0]