fix: make eos_token/pad_token overridable and add pickle support

parent 720763ede4
commit 4c846d7114

@@ -124,8 +124,12 @@ class Arcade100kTokenizer(PreTrainedTokenizer):
         self.decoder = {i: n for n, i in self.tokenizer._mergeable_ranks.items()}
         self.decoder.update({i: n for n, i in self.tokenizer._special_tokens.items()})
-        self.eos_token = self.decoder[self.tokenizer.eot_token]
-        self.pad_token = self.decoder[self.tokenizer.eot_token]
+        # Provide default `eos_token` and `pad_token`
+        if self.eos_token is None:
+            self.eos_token = self.decoder[self.tokenizer.eot_token]
+        if self.pad_token is None:
+            self.pad_token = self.decoder[self.tokenizer.pad_token]
 
         # Expose for convenience
         self.mergeable_ranks = self.tokenizer._mergeable_ranks
         self.special_tokens = self.tokenizer._special_tokens
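
With the new `if ... is None` guards, an `eos_token` or `pad_token` passed through the usual `PreTrainedTokenizer` keyword arguments is respected instead of being overwritten by the hard-coded defaults. A minimal usage sketch, assuming the tokenizer ships as custom code on the Hugging Face Hub (the checkpoint name and token string below are illustrative, not taken from this repo):

from transformers import AutoTokenizer

# Checkpoint name is a placeholder for any repo that ships Arcade100kTokenizer.
tok = AutoTokenizer.from_pretrained(
    "org/model-with-arcade100k",
    trust_remote_code=True,
    pad_token="<|pad|>",  # user-supplied value; the guard leaves it untouched
)
assert tok.pad_token == "<|pad|>"  # previously reset to the eot token's string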

@@ -133,6 +137,16 @@ class Arcade100kTokenizer(PreTrainedTokenizer):
     def __len__(self):
         return self.tokenizer.n_vocab
 
+    def __getstate__(self):
+        # Required for `pickle` support
+        state = self.__dict__.copy()
+        del state["tokenizer"]
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self.tokenizer = tiktoken.Encoding(**self._tiktoken_config)
+
     @property
     def vocab_size(self):
         return self.tokenizer.n_vocab
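
`tiktoken.Encoding` wraps a compiled BPE object that does not pickle cleanly, so `__getstate__` drops it from the state dict and `__setstate__` rebuilds it from the `_tiktoken_config` that `__init__` is assumed to store. A round-trip sketch, reusing the illustrative `tok` from above:

import pickle

restored = pickle.loads(pickle.dumps(tok))  # __getstate__/__setstate__ handle the Encoding
assert restored("hello world")["input_ids"] == tok("hello world")["input_ids"]

Among other things, this lets the tokenizer be shipped to worker processes, e.g. for multiprocess dataset preprocessing.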

@@ -273,4 +287,4 @@ class Arcade100kTokenizer(PreTrainedTokenizer):
             token_ids = [token_ids]
         if skip_special_tokens:
             token_ids = [i for i in token_ids if i < self.tokenizer.eot_token]
-            return self.tokenizer.decode(token_ids)
+        return self.tokenizer.decode(token_ids)
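
With the `return` moved out of the `if skip_special_tokens:` block, `_decode` returns the decoded string whether or not special tokens are skipped, rather than only in the skip branch. Expected behavior, continuing the sketch above (token IDs depend on the arcade100k vocabulary):

ids = tok("hello world", add_special_tokens=False)["input_ids"]
print(tok.decode(ids, skip_special_tokens=True))  # drops IDs >= eot_token, then decodes
print(tok.decode(ids))                            # now also returns the text instead of falling through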