Source code for segments.huggingface

from __future__ import annotations

import json
import logging
import os
import tempfile
from string import Template
from typing import TYPE_CHECKING, Any, Dict, cast

import requests
from PIL import Image
from segments.utils import load_image_from_url, load_label_bitmap_from_url


# https://adamj.eu/tech/2021/05/13/python-type-hints-how-to-fix-circular-imports/
if TYPE_CHECKING:
    from segments.typing import Release


#############
# Variables #
#############
logger = logging.getLogger(__name__)
try:
    import datasets
    from huggingface_hub import HfApi
except ImportError:
    # Everything below needs both packages. Log the install hint, then re-raise so
    # the import fails with the real cause instead of a confusing NameError on the
    # first use of `datasets` a few lines further down.
    logger.error("Please install HuggingFace datasets first: pip install --upgrade datasets")
    raise

# Add some functionality to the push_to_hub function of datasets.Dataset.
# Keep a reference to the stock method so the replacement can delegate to it.
push_to_hub_original = datasets.Dataset.push_to_hub

# Module-level Hub client used for the extra file uploads.
hf_api = HfApi()


#############
# Functions #
#############
def push_to_hub(self: datasets.Dataset, repo_id: str, *args: Any, **kwargs: Any) -> None:
    """Push the dataset to the Hub, then upload its label file and dataset card.

    Delegates to the original ``datasets.Dataset.push_to_hub`` first. Afterwards,
    ``id2label.json`` is uploaded if the dataset has an ``id2label`` attribute and
    ``README.md`` is uploaded if it has a ``readme`` attribute.

    Args:
        self: The dataset being pushed.
        repo_id: Id of the dataset repo on the Hub.
        *args: Positional arguments forwarded to the original ``push_to_hub``.
        **kwargs: Keyword arguments forwarded to the original ``push_to_hub``.
    """
    push_to_hub_original(self, repo_id, *args, **kwargs)

    # Stage the extra files in a private scratch directory: fixed file names in the
    # shared system temp dir can be clobbered by a concurrent run, and the directory
    # is removed automatically once the uploads are done.
    with tempfile.TemporaryDirectory() as staging_dir:
        # Upload the label file (https://huggingface.co/datasets/huggingface/label-files)
        if hasattr(self, "id2label"):
            label_path = os.path.join(staging_dir, "id2label.json")
            with open(label_path, "w", encoding="utf-8") as f:
                json.dump(self.id2label, f)

            hf_api.upload_file(
                path_or_fileobj=label_path,
                path_in_repo="id2label.json",
                repo_id=repo_id,
                repo_type="dataset",
            )

        # Upload README.md
        if hasattr(self, "readme"):
            readme_path = os.path.join(staging_dir, "README.md")
            # Write UTF-8 explicitly so the card does not depend on the locale encoding.
            with open(readme_path, "w", encoding="utf-8") as f:
                f.write(self.readme)

            hf_api.upload_file(
                path_or_fileobj=readme_path,
                path_in_repo="README.md",
                repo_id=repo_id,
                repo_type="dataset",
            )


# Monkey-patch: importing this module replaces datasets.Dataset.push_to_hub with the
# push_to_hub wrapper defined in this file, for every Dataset instance.
datasets.Dataset.push_to_hub = push_to_hub


def get_taxonomy_table(taxonomy: Dict[str, Any]) -> str:
    """Render the categories of a taxonomy as Markdown table rows.

    Only the body rows are produced (no header); each row has the form
    ``| id | name | description |`` and ends with a newline.

    Args:
        taxonomy: A dict with a ``"categories"`` list. Each category must have an
            ``"id"`` and a ``"name"``; ``"description"`` is optional.

    Returns:
        The concatenated table rows, or an empty string if there are no categories.

    Raises:
        :exc:`KeyError`: If ``"categories"``, or a category's ``"id"`` or ``"name"``, is missing.
    """

    def _cell(value: Any) -> str:
        # A raw "|" or a line break inside a cell would split the row, so collapse
        # line breaks into spaces and escape pipes.
        return " ".join(str(value).splitlines()).replace("|", "\\|")

    rows = []
    for category in taxonomy["categories"]:
        description = category.get("description")
        # A missing, null or empty description is shown as a dash placeholder.
        if description is None or description == "":
            description = "-"
        rows.append(f"| {_cell(category['id'])} | {_cell(category['name'])} | {_cell(description)} |\n")
    return "".join(rows)


def release2dataset(release: Release, download_images: bool = True) -> datasets.Dataset:
    """Create a Huggingface dataset from a release.

    Args:
        release: A Segments release resulting from :meth:`.get_release`.
        download_images: If images need to be downloaded from an AWS S3 url. Defaults to :obj:`True`.

    Returns:
        A HuggingFace dataset. The returned object also carries an ``id2label`` dict and a
        ``readme`` string as attributes.

    Raises:
        :exc:`ValueError`: If the type of dataset is not yet supported.
        :exc:`requests.HTTPError`: If downloading the release file returns an error status.
    """
    content = requests.get(cast(str, release.attributes.url))  # TODO Fix in the backend.
    # Fail on an HTTP error status here instead of trying to parse an error page as the release file.
    content.raise_for_status()
    release_dict = json.loads(content.content)

    task_type = release_dict["dataset"]["task_type"]

    if task_type in ["vector", "bboxes", "keypoint"]:
        features = datasets.Features(
            {
                "name": datasets.Value("string"),
                "uuid": datasets.Value("string"),
                "image": {"url": datasets.Value("string")},
                "status": datasets.Value("string"),
                "label": {
                    "annotations": [
                        {
                            "id": datasets.Value("int32"),
                            "category_id": datasets.Value("int32"),
                            "type": datasets.Value("string"),
                            "points": [[datasets.Value("float32")]],
                        }
                    ],
                },
            }
        )

    elif task_type in ["segmentation-bitmap", "segmentation-bitmap-highres"]:
        features = datasets.Features(
            {
                "name": datasets.Value("string"),
                "uuid": datasets.Value("string"),
                "image": {"url": datasets.Value("string")},
                "status": datasets.Value("string"),
                "label": {
                    "annotations": [
                        {
                            "id": datasets.Value("int32"),
                            "category_id": datasets.Value("int32"),
                        }
                    ],
                    "segmentation_bitmap": {"url": datasets.Value("string")},
                },
            }
        )

    else:
        raise ValueError("This type of dataset is not yet supported.")

    samples = release_dict["dataset"]["samples"]

    # Build one row per sample; a sample without a usable label or image gets placeholders.
    data_rows = []
    for sample in samples:
        # format_version is not part of the declared features, so drop it when present.
        try:
            del sample["labels"]["ground-truth"]["attributes"]["format_version"]
        except (KeyError, TypeError):
            pass

        data_row: Dict[str, Any] = {}

        # Name
        data_row["name"] = sample["name"]

        # Uuid
        data_row["uuid"] = sample["uuid"]

        # Status
        try:
            data_row["status"] = sample["labels"]["ground-truth"]["label_status"]
        except (KeyError, TypeError):
            data_row["status"] = "UNLABELED"

        # Image
        if task_type in [
            "vector",
            "bboxes",
            "keypoint",
            "segmentation-bitmap",
            "segmentation-bitmap-highres",
        ]:
            try:
                data_row["image"] = sample["attributes"]["image"]
            except (KeyError, TypeError):
                data_row["image"] = {"url": None}

        # Label
        try:
            label = sample["labels"]["ground-truth"]["attributes"]

            # Remove the image-level attributes
            if "attributes" in label:
                del label["attributes"]

            # Remove the object-level attributes
            for annotation in label["annotations"]:
                if "attributes" in annotation:
                    del annotation["attributes"]

            data_row["label"] = label

        except (KeyError, TypeError):
            # No ground-truth label: fall back to an empty label of the right shape.
            error_label: Dict[str, Any] = {"annotations": []}
            if task_type in ["segmentation-bitmap", "segmentation-bitmap-highres"]:
                error_label["segmentation_bitmap"] = {"url": None}
            data_row["label"] = error_label

        data_rows.append(data_row)

    # Now transform to column format
    dataset_dict: Dict[str, Any] = {key: [] for key in features.keys()}
    for data_row in data_rows:
        for key in dataset_dict.keys():
            dataset_dict[key].append(data_row[key])

    # Create the HF Dataset and flatten it
    dataset = datasets.Dataset.from_dict(dataset_dict, features, split="train")
    dataset = dataset.flatten()

    # Optionally download the images
    if (
        task_type
        in [
            "vector",
            "bboxes",
            "keypoint",
            "segmentation-bitmap",
            "segmentation-bitmap-highres",
        ]
        and download_images
    ):

        def download_image(data_row: Dict[str, Any]) -> Dict[str, Any]:
            # Best effort: a failed download leaves the image empty instead of aborting the map.
            try:
                data_row["image"] = load_image_from_url(data_row["image.url"])
            except Exception:
                data_row["image"] = None
            return data_row

        def download_segmentation_bitmap(data_row: Dict[str, Any]) -> Dict[str, Any]:
            # Best effort: a failed download is replaced by a 1x1 placeholder image.
            try:
                segmentation_bitmap = load_label_bitmap_from_url(data_row["label.segmentation_bitmap.url"])
                data_row["label.segmentation_bitmap"] = Image.fromarray(segmentation_bitmap)
            except Exception:
                data_row["label.segmentation_bitmap"] = Image.new("RGB", (1, 1))  # TODO: replace with None
            return data_row

        dataset = dataset.map(download_image, remove_columns=["image.url"])
        if task_type in ["segmentation-bitmap", "segmentation-bitmap-highres"]:
            dataset = dataset.map(
                download_segmentation_bitmap,
                remove_columns=["label.segmentation_bitmap.url"],
            )

            # Reorder the features
            features = datasets.Features(
                {
                    "name": dataset.features["name"],
                    "uuid": dataset.features["uuid"],
                    "status": dataset.features["status"],
                    "image": datasets.Image(),
                    "label.annotations": dataset.features["label.annotations"],
                    "label.segmentation_bitmap": datasets.Image(),
                }
            )
            dataset.info.features = features

        else:
            # Reorder the features
            features = datasets.Features(
                {
                    "name": dataset.features["name"],
                    "uuid": dataset.features["uuid"],
                    "status": dataset.features["status"],
                    "image": datasets.Image(),
                    "label.annotations": dataset.features["label.annotations"],
                }
            )
            dataset.info.features = features

    # Create id2label
    id2label = {}
    for category in release_dict["dataset"]["task_attributes"]["categories"]:
        id2label[category["id"]] = category["name"]
    id2label[0] = "unlabeled"
    dataset.id2label = id2label

    # Create readme.md and update DatasetInfo
    # https://stackoverflow.com/questions/6385686/is-there-a-native-templating-system-for-plain-text-files-in-python
    task_type = release_dict["dataset"]["task_type"]
    if task_type in ["segmentation-bitmap", "segmentation-bitmap-highres"]:
        task_category = "image-segmentation"
    elif task_type in ["vector", "bboxes"]:
        task_category = "object-detection"
    else:
        task_category = "other"

    info = {
        "name": release_dict["dataset"]["name"],
        "segments_url": f'https://segments.ai/{release_dict["dataset"]["owner"]}/{release_dict["dataset"]["name"]}',
        "short_description": release_dict["dataset"]["description"],
        "release": release_dict["name"],
        "taxonomy_table": get_taxonomy_table(release_dict["dataset"]["task_attributes"]),
        "task_category": task_category,
    }

    # Create readme.md
    # Read the template as UTF-8 explicitly so the result does not depend on the locale encoding.
    with open(
        os.path.join(os.path.dirname(__file__), "data", "dataset_card_template.md"),
        "r",
        encoding="utf-8",
    ) as f:
        template = Template(f.read())
        readme = template.substitute(info)
        dataset.readme = readme

    # Update DatasetInfo
    dataset.info.description = info["short_description"]
    dataset.info.homepage = info["segments_url"]

    return dataset