diff --git a/README.md b/README.md index 8fad649..afe4778 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ --- +pretty_name: IMDB languages: - en paperswithcode_id: imdb-movie-reviews diff --git a/imdb.py b/imdb.py index 8236ba1..49a4b6e 100644 --- a/imdb.py +++ b/imdb.py @@ -16,9 +16,6 @@ # Lint as: python3 """IMDB movie reviews dataset.""" - -import os - import datasets from datasets.tasks import TextClassification @@ -82,42 +79,33 @@ class Imdb(datasets.GeneratorBasedBuilder): task_templates=[TextClassification(text_column="text", label_column="label")], ) - def _vocab_text_gen(self, archive): - for _, ex in self._generate_examples(archive, os.path.join("aclImdb", "train")): - yield ex["text"] - def _split_generators(self, dl_manager): - arch_path = dl_manager.download_and_extract(_DOWNLOAD_URL) - data_dir = os.path.join(arch_path, "aclImdb") + archive = dl_manager.download(_DOWNLOAD_URL) return [ datasets.SplitGenerator( - name=datasets.Split.TRAIN, gen_kwargs={"directory": os.path.join(data_dir, "train")} + name=datasets.Split.TRAIN, gen_kwargs={"files": dl_manager.iter_archive(archive), "split": "train"} ), datasets.SplitGenerator( - name=datasets.Split.TEST, gen_kwargs={"directory": os.path.join(data_dir, "test")} + name=datasets.Split.TEST, gen_kwargs={"files": dl_manager.iter_archive(archive), "split": "test"} ), datasets.SplitGenerator( name=datasets.Split("unsupervised"), - gen_kwargs={"directory": os.path.join(data_dir, "train"), "labeled": False}, + gen_kwargs={"files": dl_manager.iter_archive(archive), "split": "train", "labeled": False}, ), ] - def _generate_examples(self, directory, labeled=True): - """Generate IMDB examples.""" + def _generate_examples(self, files, split, labeled=True): + """Generate aclImdb examples.""" # For labeled examples, extract the label from the path. if labeled: - files = { - "pos": sorted(os.listdir(os.path.join(directory, "pos"))), - "neg": sorted(os.listdir(os.path.join(directory, "neg"))), - } - for key in files: - for id_, file in enumerate(files[key]): - filepath = os.path.join(directory, key, file) - with open(filepath, encoding="UTF-8") as f: - yield key + "_" + str(id_), {"text": f.read(), "label": key} + label_mapping = {"pos": 1, "neg": 0} + for path, f in files: + if path.startswith(f"aclImdb/{split}"): + label = label_mapping.get(path.split("/")[2]) + if label is not None: + yield path, {"text": f.read().decode("utf-8"), "label": label} else: - unsup_files = sorted(os.listdir(os.path.join(directory, "unsup"))) - for id_, file in enumerate(unsup_files): - filepath = os.path.join(directory, "unsup", file) - with open(filepath, encoding="UTF-8") as f: - yield id_, {"text": f.read(), "label": -1} + for path, f in files: + if path.startswith(f"aclImdb/{split}"): + if path.split("/")[2] == "unsup": + yield path, {"text": f.read().decode("utf-8"), "label": -1}