# Source code for cblearn.datasets._similarity_matrix

import logging
import os
import zipfile
from pathlib import Path
from typing import Optional, Union

import joblib
import numpy as np
import scipy.io
from sklearn.datasets import _base
from sklearn.utils import Bunch

# Remote location of the zip archive ('all.zip') that bundles one ``<name>.mat``
# file per similarity dataset. fetch_similarity_matrix downloads it through
# sklearn's private ``_base._fetch_remote``; sklearn presumably verifies the
# download against ``checksum`` (64 hex digits, i.e. likely SHA-256 — confirm).
ARCHIVE = _base.RemoteFileMetadata(
    filename='all.zip',
    url='https://files.osf.io/v1/resources/ey9vp/providers/osfstorage/'
        '5e7a7065d2927f006fdd1cf9?action=download&direct&version=1',
    checksum=('8c799cdebb00192ecb63f3e28c6eeee0e2f64fcb8dad3bc68982e551f2ae5b1c'))

# Module-level logger; fetch_similarity_matrix uses it to report archive downloads.
logger = logging.getLogger(__name__)

# Dataset names accepted by fetch_similarity_matrix. Each name is also the stem of
# a ``<name>.mat`` member inside the ARCHIVE zip and of the ``<name>.pkz`` file
# written to the local cache. Keep these in sync with the archive contents.
AVAILABLE_SIMILARITIES = [
    'fruit2_romney', 'nonsense_romney', 'furniture_romney', 'kinship_kimrosenberg', 'rectangle_kruschke',
    'vegetables2_romney', 'animalpictures5', 'auditory', 'druguse', 'faces11', 'fruits', 'dotpatterns',
    'furniture2_romney', 'bodies_viken', 'textures', 'sport_romney', 'bankwiring', 'morsenumbers',
    'faces_busey', 'letters', 'vehicles_romney', 'vehicles2_romney', 'birds_romney', 'fruit_romney', 'risks',
    'morseall', 'texturemit_heaps', 'cartoonfaces', 'country_robinsonhefner', 'congress', 'phonemes',
    'toys_romney', 'colour', 'countriessim', 'faces5', 'tools_romney', 'lines_cohen', 'abstractnumbers',
    'countriesdis', 'animalnames11', 'faces_steyvers', 'weapons2_romney', 'texturebrodatz_heaps',
    'fish_romney', 'flowerpots', 'sizeangle_treat', 'clothing2_romney', 'weapons_romney', 'clothing_romney',
    'animalnames5', 'vegetables_romney', 'animalpictures11']


def fetch_similarity_matrix(name: str, data_home: Optional[os.PathLike] = None,
                            download_if_missing: bool = True) -> Union[Bunch, np.ndarray]:
    """ Load human similarity judgements, aggregated to a similarity matrix.

    This function provides access to the following similarity matrices:
    `fruit2_romney, nonsense_romney, furniture_romney, kinship_kimrosenberg, rectangle_kruschke,
    vegetables2_romney, animalpictures5, auditory, druguse, faces11, fruits, dotpatterns,
    furniture2_romney, bodies_viken, textures, sport_romney, bankwiring, morsenumbers,
    faces_busey, letters, vehicles_romney, vehicles2_romney, birds_romney, fruit_romney, risks,
    morseall, texturemit_heaps, cartoonfaces, country_robinsonhefner, congress, phonemes,
    toys_romney, colour, countriessim, faces5, tools_romney, lines_cohen, abstractnumbers,
    countriesdis, animalnames11, faces_steyvers, weapons2_romney, texturebrodatz_heaps,
    fish_romney, flowerpots, sizeangle_treat, clothing2_romney, weapons_romney, clothing_romney,
    animalnames5, vegetables_romney, animalpictures11`.

    See :ref:`similarity_matrix_dataset` for a detailed description.

    On first use, the archive containing all similarity datasets is downloaded once.
    Every dataset is converted into its own compressed ``<name>.pkz`` cache file, and
    the archive is deleted afterwards. Later calls load the cached file directly.

    >>> dataset = fetch_similarity_matrix('colour')  # doctest: +REMOTE_DATA
    >>> dataset.labels[:2].tolist()  # doctest: +REMOTE_DATA
    ['434', '445']
    >>> dataset.similarity.shape  # doctest: +REMOTE_DATA
    (14, 14)

    Args:
        name: Name of the similarity dataset; one of ``AVAILABLE_SIMILARITIES``.
        data_home : optional, default: None
            Specify another download and cache folder for the datasets.
            By default all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
        download_if_missing : optional, default=True
            If False, raise an IOError when the data is not cached locally
            instead of downloading it from the remote archive.
    Returns:
        dataset : :class:`~sklearn.utils.Bunch`
            Dictionary-like object, with the following attributes.

            similarity : ndarray, shape (n_objects, n_objects)
                Symmetric matrix of normalized object similarities.
                None for datasets that do not provide it.
            proximity : ndarray, shape (n_objects, n_objects)
                Symmetric matrix of normalized pairwise proximities.
                None for datasets that do not provide it.
            n_objects: int
                Number of objects
            labels : (n_objects,)
                Single word describing each object
            sigma: float
                Uncertainty of the similarity values. NaN if not available for the dataset.
            DESCR : string
                Description of the dataset.
    Raises:
        ValueError: If ``name`` is not one of the available similarity datasets.
        IOError: If the data is not locally available, but download_if_missing=False
    """
    if name not in AVAILABLE_SIMILARITIES:
        raise ValueError(f"Unexpected similarity name = {name}. Use one of {AVAILABLE_SIMILARITIES}.")

    data_home = Path(_base.get_data_home(data_home=data_home))
    data_home.mkdir(parents=True, exist_ok=True)
    # The cache directory name is derived through sklearn's private _pkl_filepath
    # helper; keep this call unchanged so caches from earlier versions are reused.
    basepath = Path(_base._pkl_filepath(data_home, 'similarity_collection/'))
    filepath = basepath.joinpath(f'{name}.pkz')

    if filepath.exists():
        data_dict = joblib.load(filepath)
    else:
        if not download_if_missing:
            raise IOError("Data not found and `download_if_missing` is False")
        data_dict = _download_and_cache(name, data_home, basepath)

    # Caches written by earlier versions stored np.array(None) for missing matrices;
    # normalize so the documented contract ("None for some datasets") holds.
    data_dict = dict(data_dict,
                     similarity=_optional_matrix(data_dict.get('similarity')),
                     proximity=_optional_matrix(data_dict.get('proximity')))

    descr_path = Path(__file__).parent.joinpath('descr', 'similarity_matrix.rst')
    fdescr = descr_path.read_text(encoding='utf-8')
    return Bunch(**data_dict, DESCR=fdescr)


def _download_and_cache(name, data_home, basepath):
    """Download the archive, cache every dataset as ``.pkz`` and return ``name``'s fields.

    The downloaded archive is always removed afterwards, even if extraction fails.
    """
    logger.info('Downloading similarity matrix data from %s to %s', ARCHIVE.url, data_home)
    archive_path = _base._fetch_remote(ARCHIVE, dirname=data_home)
    data_dict = None
    try:
        basepath.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(archive_path) as zf:
            # Convert all datasets at once so the archive only has to be fetched once.
            for this_name in AVAILABLE_SIMILARITIES:
                with zf.open(f'{this_name}.mat', 'r') as f:
                    this_dict = _mat_to_dict(scipy.io.loadmat(f))
                joblib.dump(this_dict, basepath.joinpath(f'{this_name}.pkz'), compress=6)
                if this_name == name:
                    data_dict = this_dict
    finally:
        os.remove(archive_path)
    return data_dict


def _mat_to_dict(raw):
    """Convert the mapping returned by ``scipy.io.loadmat`` into the dataset fields."""
    return {
        'similarity': _optional_matrix(raw.get('s')),
        'proximity': _optional_matrix(raw.get('d')),
        'n_objects': int(raw['n'][0, 0]),
        'labels': np.array(raw['labs'], dtype=str),
        # Missing uncertainty is reported as NaN rather than omitted.
        'sigma': float(raw.get('sigma_emp', np.array([[np.nan]]))[0, 0]),
    }


def _optional_matrix(value):
    """Return ``value`` as an ndarray, or None if the matrix is absent.

    Treats both ``None`` and a 0-d object array wrapping ``None`` as absent.
    """
    if value is None:
        return None
    matrix = np.asarray(value)
    if matrix.dtype == object and matrix.ndim == 0 and matrix.item() is None:
        return None
    return matrix