fix: make eos_token/pad_token overridable and add pickle support
commit 4c846d7114
parent 720763ede4
@@ -124,8 +124,12 @@ class Arcade100kTokenizer(PreTrainedTokenizer):
         self.decoder = {i: n for n, i in self.tokenizer._mergeable_ranks.items()}
         self.decoder.update({i: n for n, i in self.tokenizer._special_tokens.items()})
-        self.eos_token = self.decoder[self.tokenizer.eot_token]
-        self.pad_token = self.decoder[self.tokenizer.eot_token]
+        # Provide default `eos_token` and `pad_token`
+        if self.eos_token is None:
+            self.eos_token = self.decoder[self.tokenizer.eot_token]
+        if self.pad_token is None:
+            self.pad_token = self.decoder[self.tokenizer.pad_token]

         # Expose for convenience
         self.mergeable_ranks = self.tokenizer._mergeable_ranks
         self.special_tokens = self.tokenizer._special_tokens
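Note: with this hunk, an `eos_token` or `pad_token` passed through the standard `PreTrainedTokenizer` kwargs is no longer clobbered; the decoder-derived values now apply only as fallbacks. A minimal sketch of the intended behavior (the constructor arguments below are illustrative placeholders, not taken from this commit):

    # Hypothetical usage; "arcade100k.tiktoken" and "<|pad|>" are placeholders.
    tok = Arcade100kTokenizer(vocab_file="arcade100k.tiktoken", pad_token="<|pad|>")

    # `pad_token` was set from kwargs by the base class, so the
    # `if self.pad_token is None:` fallback is skipped and the override survives.
    assert tok.pad_token == "<|pad|>"

    # `eos_token` was not passed, so it falls back to the decoder entry
    # for `tokenizer.eot_token`.
    print(tok.eos_token)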
@@ -133,6 +137,16 @@ class Arcade100kTokenizer(PreTrainedTokenizer):
     def __len__(self):
         return self.tokenizer.n_vocab

+    def __getstate__(self):
+        # Required for `pickle` support
+        state = self.__dict__.copy()
+        del state["tokenizer"]
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self.tokenizer = tiktoken.Encoding(**self._tiktoken_config)
+
     @property
     def vocab_size(self):
         return self.tokenizer.n_vocab
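Note: `tiktoken.Encoding` wraps a Rust `CoreBPE` object that the `pickle` module cannot serialize, so `__getstate__` drops it from the instance state and `__setstate__` rebuilds it from the stored `_tiktoken_config`. A minimal round-trip sketch (the vocab path is a placeholder):

    import pickle

    tok = Arcade100kTokenizer(vocab_file="arcade100k.tiktoken")  # placeholder path

    blob = pickle.dumps(tok)       # __getstate__ strips the Encoding instance
    restored = pickle.loads(blob)  # __setstate__ re-creates it from the config

    assert restored.tokenizer.n_vocab == tok.tokenizer.n_vocab

Among other things, this lets the tokenizer cross process boundaries, e.g. in multiprocessing data pipelines that pickle their arguments.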