diff --git a/tokenization_arcade100k.py b/tokenization_arcade100k.py
index be91425..ddbfa46 100644
--- a/tokenization_arcade100k.py
+++ b/tokenization_arcade100k.py
@@ -124,8 +124,12 @@ class Arcade100kTokenizer(PreTrainedTokenizer):
         self.decoder = {i: n for n, i in self.tokenizer._mergeable_ranks.items()}
         self.decoder.update({i: n for n, i in self.tokenizer._special_tokens.items()})
-        self.eos_token = self.decoder[self.tokenizer.eot_token]
-        self.pad_token = self.decoder[self.tokenizer.eot_token]
+        # Provide default `eos_token` and `pad_token`
+        if self.eos_token is None:
+            self.eos_token = self.decoder[self.tokenizer.eot_token]
+        if self.pad_token is None:
+            self.pad_token = self.decoder[self.tokenizer.pad_token]
+
         # Expose for convenience
         self.mergeable_ranks = self.tokenizer._mergeable_ranks
         self.special_tokens = self.tokenizer._special_tokens
 
@@ -133,6 +137,16 @@ class Arcade100kTokenizer(PreTrainedTokenizer):
     def __len__(self):
         return self.tokenizer.n_vocab
 
+    def __getstate__(self):
+        # Required for `pickle` support
+        state = self.__dict__.copy()
+        del state["tokenizer"]
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self.tokenizer = tiktoken.Encoding(**self._tiktoken_config)
+
     @property
     def vocab_size(self):
         return self.tokenizer.n_vocab
@@ -273,4 +287,4 @@ class Arcade100kTokenizer(PreTrainedTokenizer):
             token_ids = [token_ids]
         if skip_special_tokens:
             token_ids = [i for i in token_ids if i < self.tokenizer.eot_token]
-        return self.tokenizer.decode(token_ids)
\ No newline at end of file
+        return self.tokenizer.decode(token_ids)