fix: make eos_token/pad_token overridable and add pickle support

This commit is contained in:
jon-tow 2024-01-22 23:45:51 -05:00
parent 720763ede4
commit 4c846d7114

@ -124,8 +124,12 @@ class Arcade100kTokenizer(PreTrainedTokenizer):
self.decoder = {i: n for n, i in self.tokenizer._mergeable_ranks.items()} self.decoder = {i: n for n, i in self.tokenizer._mergeable_ranks.items()}
self.decoder.update({i: n for n, i in self.tokenizer._special_tokens.items()}) self.decoder.update({i: n for n, i in self.tokenizer._special_tokens.items()})
# Provide default `eos_token` and `pad_token`
if self.eos_token is None:
self.eos_token = self.decoder[self.tokenizer.eot_token] self.eos_token = self.decoder[self.tokenizer.eot_token]
self.pad_token = self.decoder[self.tokenizer.eot_token] if self.pad_token is None:
self.pad_token = self.decoder[self.tokenizer.pad_token]
# Expose for convenience # Expose for convenience
self.mergeable_ranks = self.tokenizer._mergeable_ranks self.mergeable_ranks = self.tokenizer._mergeable_ranks
self.special_tokens = self.tokenizer._special_tokens self.special_tokens = self.tokenizer._special_tokens
@ -133,6 +137,16 @@ class Arcade100kTokenizer(PreTrainedTokenizer):
def __len__(self): def __len__(self):
return self.tokenizer.n_vocab return self.tokenizer.n_vocab
def __getstate__(self):
# Required for `pickle` support
state = self.__dict__.copy()
del state["tokenizer"]
return state
def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = tiktoken.Encoding(**self._tiktoken_config)
@property @property
def vocab_size(self): def vocab_size(self):
return self.tokenizer.n_vocab return self.tokenizer.n_vocab