# Source code for cblearn.datasets._food_similarity
from pathlib import Path
import logging
import joblib
import os
from typing import Optional, Union
import zipfile
import ssl
import numpy as np
from sklearn.datasets import _base
from sklearn.utils import check_random_state, Bunch
# Download metadata for the Food-100 archive: the source URL, the file name
# the download is stored under, and the checksum the file is expected to have.
ARCHIVE = _base.RemoteFileMetadata(
    url='http://vision.cornell.edu/se3/wp-content/uploads/2014/09/food100-dataset.zip',
    filename='food100-dataset.zip',
    checksum='18f5e210174dfdbf6a7b4ed7538cf8ba53fd65e0cbe193519231b8ab4ea8fc62',
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# [docs]
def fetch_food_similarity(data_home: Optional[os.PathLike] = None, download_if_missing: bool = True,
                          shuffle: bool = True, random_state: Optional[np.random.RandomState] = None,
                          return_triplets: bool = False) -> Union[Bunch, np.ndarray]:
    """ Load the Food-100 food similarity dataset (triplets).

    .. warning::
        This function downloads the file without verifying the ssl signature
        to circumvent an outdated certificate of the dataset hosts.
        However, after downloading the function verifies the file checksum
        before loading the file to minimize the risk of man-in-the-middle attacks.

    ===================   =====================
    Triplets                             190376
    Objects                                 100
    Dimensionality                      unknown
    ===================   =====================

    See :ref:`food_similarity_dataset` for a detailed description.

    Args:
        data_home : optional, default: None
            Specify another download and cache folder for the datasets. By default
            all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
        download_if_missing : optional, default=True
            If False, raise an IOError when the data is not cached locally
            instead of downloading it from the source site.
        shuffle: default = True
            Shuffle the order of triplet constraints.
        random_state: optional, default = None
            Initialization for shuffle random generator
        return_triplets : boolean, default=False.
            If True, returns numpy array instead of a Bunch object.

    Returns:
        dataset : :class:`~sklearn.utils.Bunch`
            Returned when `return_triplets=False`.
            Dictionary-like object, with the following attributes.

            data: ndarray, shape (n_triplets, 3)
                Each row corresponds to a triplet constraint.
                The columns represent the target, more similar and more distant food index.
            image_names : ndarray, shape (n_objects,)
                The food image names corresponding to the indices.
            DESCR : string
                Description of the dataset.

        triplets : numpy array (n_triplets, 3)
            Returned instead of the Bunch when `return_triplets=True`.

    Raises:
        IOError: If the data is not locally available, but download_if_missing=False
    """
    data_home = Path(_base.get_data_home(data_home=data_home))
    # Create missing parents too and tolerate a concurrently created directory.
    data_home.mkdir(parents=True, exist_ok=True)

    filepath = Path(_base._pkl_filepath(data_home, 'food_similarity.pkz'))
    if not filepath.exists():
        if not download_if_missing:
            raise IOError("Data not found and `download_if_missing` is False")

        logger.info('Downloading food similarity from %s to %s', ARCHIVE.url, data_home)

        # Temporarily disable certificate verification for the download only;
        # the previous context factory is restored even if the download fails.
        ssl_default = ssl._create_default_https_context
        try:
            ssl._create_default_https_context = ssl._create_unverified_context
            archive_path = _base._fetch_remote(ARCHIVE, dirname=data_home)
        finally:
            ssl._create_default_https_context = ssl_default

        with zipfile.ZipFile(archive_path) as zf:
            with zf.open('food100-dataset/all-triplets.csv', 'r') as f:
                triplets = np.loadtxt(f, dtype=str, delimiter=';')
                triplets = np.char.strip(triplets)  # trim whitespace
            # Image names are the archive members below 'images/', stored
            # relative to the archive's top-level folder.
            prefix = 'food100-dataset/'
            image_names = np.asarray([name[len(prefix):]
                                      for name in zf.namelist()
                                      if name.startswith('food100-dataset/images/')
                                      and name.endswith('.jpg')])

        # Cache the parsed data, then drop the downloaded archive.
        joblib.dump((triplets, image_names), filepath, compress=6)
        os.remove(archive_path)
    else:
        triplets, image_names = joblib.load(filepath)

    # Replace every image name in the triplets by its position in the sorted
    # name list.
    # NOTE(review): np.searchsorted returns insertion positions, so a name that
    # is absent from `image_names` would silently map to a neighbouring index.
    image_names = np.sort(image_names)
    triplets = np.searchsorted(image_names, triplets)

    if shuffle:
        random_state = check_random_state(random_state)
        triplets = random_state.permutation(triplets)

    if return_triplets:
        # The description file is only needed for the Bunch, so skip reading it.
        return triplets

    module_path = Path(__file__).parent
    with module_path.joinpath('descr', 'food_similarity.rst').open() as rst_file:
        fdescr = rst_file.read()

    return Bunch(data=triplets,
                 image_names=image_names,
                 DESCR=fdescr)