Update files from the datasets library (from 1.16.0)

Release notes: https://github.com/huggingface/datasets/releases/tag/1.16.0
This commit is contained in:
system 2022-01-25 16:50:34 +01:00
parent 0c5d6316a3
commit 9a403d6ee7
2 changed files with 17 additions and 28 deletions

@ -1,4 +1,5 @@
---
pretty_name: IMDB
languages:
- en
paperswithcode_id: imdb-movie-reviews

44
imdb.py

@ -16,9 +16,6 @@
# Lint as: python3
"""IMDB movie reviews dataset."""
import os
import datasets
from datasets.tasks import TextClassification
@ -82,42 +79,33 @@ class Imdb(datasets.GeneratorBasedBuilder):
task_templates=[TextClassification(text_column="text", label_column="label")],
)
def _vocab_text_gen(self, archive):
for _, ex in self._generate_examples(archive, os.path.join("aclImdb", "train")):
yield ex["text"]
def _split_generators(self, dl_manager):
arch_path = dl_manager.download_and_extract(_DOWNLOAD_URL)
data_dir = os.path.join(arch_path, "aclImdb")
archive = dl_manager.download(_DOWNLOAD_URL)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"directory": os.path.join(data_dir, "train")}
name=datasets.Split.TRAIN, gen_kwargs={"files": dl_manager.iter_archive(archive), "split": "train"}
),
datasets.SplitGenerator(
name=datasets.Split.TEST, gen_kwargs={"directory": os.path.join(data_dir, "test")}
name=datasets.Split.TEST, gen_kwargs={"files": dl_manager.iter_archive(archive), "split": "test"}
),
datasets.SplitGenerator(
name=datasets.Split("unsupervised"),
gen_kwargs={"directory": os.path.join(data_dir, "train"), "labeled": False},
gen_kwargs={"files": dl_manager.iter_archive(archive), "split": "train", "labeled": False},
),
]
def _generate_examples(self, directory, labeled=True):
"""Generate IMDB examples."""
def _generate_examples(self, files, split, labeled=True):
"""Generate aclImdb examples."""
# For labeled examples, extract the label from the path.
if labeled:
files = {
"pos": sorted(os.listdir(os.path.join(directory, "pos"))),
"neg": sorted(os.listdir(os.path.join(directory, "neg"))),
}
for key in files:
for id_, file in enumerate(files[key]):
filepath = os.path.join(directory, key, file)
with open(filepath, encoding="UTF-8") as f:
yield key + "_" + str(id_), {"text": f.read(), "label": key}
label_mapping = {"pos": 1, "neg": 0}
for path, f in files:
if path.startswith(f"aclImdb/{split}"):
label = label_mapping.get(path.split("/")[2])
if label is not None:
yield path, {"text": f.read().decode("utf-8"), "label": label}
else:
unsup_files = sorted(os.listdir(os.path.join(directory, "unsup")))
for id_, file in enumerate(unsup_files):
filepath = os.path.join(directory, "unsup", file)
with open(filepath, encoding="UTF-8") as f:
yield id_, {"text": f.read(), "label": -1}
for path, f in files:
if path.startswith(f"aclImdb/{split}"):
if path.split("/")[2] == "unsup":
yield path, {"text": f.read().decode("utf-8"), "label": -1}