Source code for nums.core.array.utils

# Copyright (C) 2020 NumS Development Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import itertools
from typing import Tuple, Iterator

import numpy as np
import scipy.special

from nums.core.settings import np_ufunc_map
from nums.core.array.errors import AxisError

# pylint: disable = no-member, trailing-whitespace


def to_dtype_cls(dtype):
    """Return the scalar type class that corresponds to ``dtype``.

    Parameters
    ----------
    dtype
        Either a type class (any object exposing ``__name__``, such as
        ``float`` or ``np.float32``), which is returned unchanged, or a
        dtype-like object such as an ``np.dtype`` instance.

    Returns
    -------
    type
        The attribute of the ``numpy`` module named ``str(dtype)`` when it
        exists; otherwise the scalar type reported by ``np.dtype(dtype).type``.
    """
    if hasattr(dtype, "__name__"):
        return dtype
    try:
        return getattr(np, str(dtype))
    except AttributeError:
        # Not every dtype name is a numpy attribute (e.g. "bool" on some
        # NumPy releases, or "<U4"); ask the dtype for its scalar type instead.
        return np.dtype(dtype).type
def get_uop_output_type(op_name, dtype):
    """Return the type class produced by a unary NumPy operation.

    The operation named ``op_name`` is looked up on the ``numpy`` module and
    applied to a 0-d array holding ``1`` with the given ``dtype``; the dtype of
    the result is converted with ``to_dtype_cls``.
    """
    probe = np.array(1, dtype=dtype)
    unary_op = getattr(np, op_name)
    return to_dtype_cls(unary_op(probe).dtype)
def get_bop_output_type(op_name, dtype_a, dtype_b):
    """Return the type class produced by a binary operation on two dtypes.

    ``op_name`` is first translated through ``np_ufunc_map`` (falling back to
    the name itself). The operation is applied to 0-d probe arrays holding
    ``1`` and ``2``. It is looked up on ``numpy`` first; if that attempt fails
    with anything other than ``TypeError``, ``scipy.special`` is tried.

    Raises
    ------
    TypeError
        Propagated unchanged from the NumPy attempt.
    """
    lhs = np.array(1, dtype=dtype_a)
    rhs = np.array(2, dtype=dtype_b)
    resolved_name = np_ufunc_map.get(op_name, op_name)
    try:
        numpy_result = getattr(np, resolved_name)(lhs, rhs)
        return to_dtype_cls(numpy_result.dtype)
    except TypeError:
        # A TypeError is never retried with scipy; let it propagate.
        raise
    except Exception:
        scipy_result = getattr(scipy.special, resolved_name)(lhs, rhs)
        return to_dtype_cls(scipy_result.dtype)
def is_index_subscript(val):
    """Return True when ``val`` is a signed or unsigned integer instance."""
    if is_int(val):
        return True
    return is_uint(val)
def is_regular_subscript(val):
    """Return True when ``val`` is a slice or an integer index."""
    if isinstance(val, slice):
        return True
    return is_index_subscript(val)
def is_scalar(val):
    """Return True when ``val`` is an instance of a supported scalar type."""
    supported = is_supported(val)
    return supported
def is_supported(val, type_test=False):
    """Return True when ``val`` passes any of the scalar-kind predicates.

    The predicates are tried in the order bool, uint, int, float, complex and
    evaluation stops at the first one that succeeds. ``type_test`` is forwarded
    to each predicate.
    """
    predicates = (is_bool, is_uint, is_int, is_float, is_complex)
    return any(predicate(val, type_test) for predicate in predicates)
def is_bool(val, type_test=False):
    """Check ``val`` against the boolean types ``bool`` and ``np.bool_``.

    ``type_test`` is forwarded to ``is_type`` and selects between membership
    and ``isinstance`` testing.
    """
    bool_types = (bool, np.bool_)
    return is_type(type_test, val, bool_types)
def is_uint(val, type_test=False):
    """Check ``val`` against NumPy's unsigned integer types.

    ``type_test`` is forwarded to ``is_type`` and selects between membership
    and ``isinstance`` testing.
    """
    uint_types = (np.uint, np.uint8, np.uint16, np.uint32, np.uint64)
    return is_type(type_test, val, uint_types)
def is_int(val, type_test=False):
    """Check ``val`` against ``int`` and NumPy's fixed-width signed integers.

    ``type_test`` is forwarded to ``is_type`` and selects between membership
    and ``isinstance`` testing.
    """
    int_types = (int, np.int8, np.int16, np.int32, np.int64)
    return is_type(type_test, val, int_types)
def is_float(val, type_test=False):
    """Check ``val`` against ``float`` and NumPy's 16/32/64-bit float types.

    ``type_test`` is forwarded to ``is_type`` and selects between membership
    and ``isinstance`` testing.
    """
    float_types = (float, np.float16, np.float32, np.float64)
    return is_type(type_test, val, float_types)
def is_complex(val, type_test=False):
    """Check ``val`` against ``np.complex64`` and ``np.complex128``.

    ``type_test`` is forwarded to ``is_type`` and selects between membership
    and ``isinstance`` testing.
    """
    complex_types = (np.complex64, np.complex128)
    return is_type(type_test, val, complex_types)
def is_type(type_test, val, types):
    """Test ``val`` against a tuple of types.

    Parameters
    ----------
    type_test : bool
        When true, ``val`` is treated as a type-like object and the result is
        ``val in types``. Otherwise ``val`` is treated as a value and the
        result is ``isinstance(val, types)``.
    val
        The object to test.
    types : tuple
        Candidate types.
    """
    if type_test:
        return val in types
    return isinstance(val, types)
def get_reduce_output_type(op_name, dtype):
    """Return the type class produced by a NumPy reduction on ``dtype``.

    The reduction named ``op_name`` is looked up on the ``numpy`` module and
    applied to the probe array ``[0, 1]`` created with ``dtype``.

    Returns
    -------
    type
        The attribute of ``numpy`` named after the result dtype when it
        exists; otherwise the result dtype's scalar type.
    """
    probe = np.array([0, 1], dtype=dtype)
    result_dtype = getattr(np, op_name)(probe).dtype
    try:
        return getattr(np, str(result_dtype))
    except AttributeError:
        # Not every dtype name is a numpy attribute (e.g. "bool" on some NumPy
        # releases, as produced by "all"/"any"); use the dtype's scalar type.
        return result_dtype.type
def shape_from_block_array(arr: np.ndarray):
    """Compute the overall shape covered by a grid of blocks.

    For every axis, the extent is the sum of ``block.shape[axis]`` over the
    blocks found by walking that axis of ``arr`` while every other grid
    coordinate is held at 0.

    Parameters
    ----------
    arr : np.ndarray
        Array whose entries each expose a ``shape`` attribute.

    Returns
    -------
    tuple
        One summed extent per axis of ``arr``.
    """
    grid_shape = arr.shape
    ndim = len(grid_shape)
    totals = np.zeros(ndim, dtype=int)
    for axis in range(ndim):
        for step in range(grid_shape[axis]):
            # Grid coordinate that varies only along ``axis``.
            entry = tuple(step if d == axis else 0 for d in range(ndim))
            totals[axis] += arr[entry].shape[axis]
    return tuple(totals)
def broadcast(a_shape, b_shape):
    """Broadcast two shapes against each other without allocating data.

    Each shape is represented by a view of the scalar ``0`` broadcast to that
    shape, and the two views are combined with ``np.broadcast``.

    Returns
    -------
    np.broadcast
        The broadcast object; its ``shape`` attribute is the result shape.

    Raises
    ------
    ValueError
        If the shapes cannot be broadcast together.
    """
    # np.broadcast_to is the public entry point; the original reached it
    # through the np.lib.stride_tricks namespace, which is not guaranteed to
    # re-export it.
    a_view = np.broadcast_to(0, a_shape)
    b_view = np.broadcast_to(0, b_shape)
    return np.broadcast(a_view, b_view)
def broadcast_block_shape(a_shape, b_shape, a_block_shape):
    """Derive a block shape for the broadcast of ``a_shape`` and ``b_shape``.

    The trailing dimensions of the result reuse ``a_block_shape`` (aligned
    from the last dimension); any remaining leading dimensions get block
    size 1.
    """
    result_ndim = len(broadcast(a_shape, b_shape).shape)
    trailing_first = list(reversed(a_block_shape))
    known = len(trailing_first)
    # Built last-dimension-first, then flipped back.
    reversed_blocks = [
        trailing_first[pos] if pos < known else 1 for pos in range(result_ndim)
    ]
    return tuple(reversed(reversed_blocks))
def broadcast_shape(a_shape, b_shape):
    """Return the shape obtained by broadcasting ``a_shape`` with ``b_shape``."""
    broadcast_obj = broadcast(a_shape, b_shape)
    return broadcast_obj.shape
def can_broadcast_shapes(a_shape, b_shape):
    """Return True when the two shapes can be broadcast together.

    A ``ValueError`` from ``broadcast_shape`` is interpreted as "not
    broadcastable".
    """
    try:
        combined = broadcast_shape(a_shape, b_shape)
    except ValueError:
        return False
    assert combined is not None
    return True
def broadcastable(a_shape, b_shape, a_block_shape, b_block_shape):
    """Return True when two blocked arrays can be combined by broadcasting.

    Parameters
    ----------
    a_shape, b_shape : tuple
        Array shapes.
    a_block_shape, b_block_shape : tuple
        Block shapes of the respective arrays.

    Returns
    -------
    bool
        For identical array shapes, whether the block shapes are equal.
        Otherwise False when the array shapes cannot be broadcast, and, when
        they can, whether every pair of block dimensions (aligned from the last
        dimension) is equal or has a 1 on either side.
    """
    if a_shape == b_shape:
        return a_block_shape == b_block_shape
    try:
        broadcast_shape(a_shape, b_shape)
    except ValueError:
        # broadcast_shape signals incompatible shapes by raising, not by
        # returning None; report that as "not broadcastable".
        return False
    # zip stops at the shorter block shape, so only the trailing dimensions
    # that both block shapes have are compared.
    for a_dim, b_dim in zip(reversed(a_block_shape), reversed(b_block_shape)):
        if a_dim != b_dim and a_dim != 1 and b_dim != 1:
            return False
    return True
def is_1d(shape):
    """Return True when exactly one dimension of ``shape`` differs from 1."""
    non_unit_dims = sum(1 for dim in shape if dim != 1)
    return non_unit_dims == 1
def broadcast_shape_to(from_shape, to_shape):
    """Broadcast a placeholder of shape ``from_shape`` to ``to_shape``.

    NumPy's broadcasting rules are enforced by broadcasting a view of the
    scalar ``0`` (shaped ``from_shape``) to ``to_shape``; no data proportional
    to either shape is allocated.

    Returns
    -------
    np.ndarray
        The broadcast view with shape ``to_shape``.

    Raises
    ------
    ValueError
        If ``from_shape`` cannot be broadcast to ``to_shape``.
    """
    # np.broadcast_to is the public entry point; the original reached it
    # through the np.lib.stride_tricks namespace, which is not guaranteed to
    # re-export it.
    from_view = np.broadcast_to(0, from_shape)
    return np.broadcast_to(from_view, to_shape)
def can_broadcast_shape_to(from_shape, to_shape):
    """Return True when ``from_shape`` can be broadcast to ``to_shape``.

    See: https://numpy.org/devdocs/user/theory.broadcasting.html
    """
    try:
        broadcast_shape_to(from_shape, to_shape)
    except ValueError:
        return False
    return True
def broadcast_shape_to_alt(from_shape, to_shape):
    """Compute the result of broadcasting ``from_shape`` to ``to_shape``.

    This is a pure-Python check that does not build any NumPy views.
    Dimensions are aligned from the last axis; a source dimension must be 1
    or equal to the matching target dimension.

    Parameters
    ----------
    from_shape, to_shape : tuple
        Source and target shapes.

    Returns
    -------
    tuple
        The broadcast shape.

    Raises
    ------
    ValueError
        If ``from_shape`` has more dimensions than ``to_shape`` or a pair of
        aligned dimensions is incompatible.
    """
    from_num_axes = len(from_shape)
    to_num_axes = len(to_shape)
    if to_num_axes < from_num_axes:
        raise ValueError(
            "Input shape has more dimensions than allowed by the axis remapping."
        )
    # The original also tested `to_num_axes == 0 and from_shape != 0`. A tuple
    # never equals 0, so that test rejected the valid () -> () case, and every
    # non-scalar -> () case is already rejected by the check above.
    from_shape_r = list(reversed(from_shape))
    to_shape_r = list(reversed(to_shape))
    result_shape = []
    for from_dim, to_dim in zip(from_shape_r, to_shape_r):
        if from_dim == 1 or from_dim == to_dim:
            result_shape.append(to_dim)
        else:
            raise ValueError(
                "Cannot broadcast %s to %s." % (str(from_shape), str(to_shape))
            )
    # Leading target dimensions with no source counterpart carry over as-is.
    return tuple(reversed(result_shape + to_shape_r[from_num_axes:]))
def is_array_like(obj):
    """Return True when ``obj`` is a tuple, a list or a NumPy array."""
    array_like_types = (tuple, list, np.ndarray)
    return isinstance(obj, array_like_types)
def block_shape_from_subscript(subscript: tuple, block_shape: tuple):
    """Return the block shape that remains after applying ``subscript``.

    Axes selected with a slice keep their block size; axes selected with an
    integer index are dropped.

    Raises
    ------
    NotImplementedError
        If an entry is neither a slice nor an integer index.
    """
    kept = []
    for axis, entry in enumerate(subscript):
        if isinstance(entry, slice):
            kept.append(block_shape[axis])
        elif not is_regular_subscript(entry):
            raise NotImplementedError("No support for advanced indexing.")
    return tuple(kept)
def get_slices(total_size, batch_size, order, reverse_blocks=False):
    """Split an axis of length ``total_size`` into consecutive slices.

    Parameters
    ----------
    total_size : int
        Length of the axis.
    batch_size : int
        Size of each block; one block may be smaller.
    order : int
        ``1`` for forward slices with non-negative bounds, ``-1`` for slices
        with negative bounds and step ``-1``.
    reverse_blocks : bool
        When true, block boundaries are laid out starting from the far end, so
        a smaller leftover block is the first slice rather than the last.

    Returns
    -------
    list of slice
        Slices whose step equals ``order``.
    """
    assert order in (-1, 1)
    if order > 0:
        if reverse_blocks:
            bounds = (list(range(total_size, 0, -batch_size)) + [0])[::-1]
        else:
            bounds = list(range(0, total_size, batch_size)) + [total_size]
    elif reverse_blocks:
        # If reverse order blocks are not multiples of axis dimension,
        # then the last block is smaller than block size and should be
        # the first block.
        bounds = (list(range(-total_size - 1, -1, batch_size)) + [-1])[::-1]
    else:
        bounds = list(range(-1, -total_size - 1, -batch_size)) + [-total_size - 1]
    return [slice(start, stop, order) for start, stop in zip(bounds, bounds[1:])]
class OrderedGrid:
    """Grid of slices that partitions an array shape into blocks.

    Attributes set by the constructor:

    - ``shape``: the array shape as a tuple.
    - ``block_shape``: element-wise minimum of ``shape`` and the requested
      block shape.
    - ``order``: per-axis slice direction, as a tuple.
    - ``grid_slices``: one list of slices per axis, produced by ``get_slices``.
    - ``grid_shape``: number of slices along each axis.
    - ``slices``: object array of shape ``grid_shape + (len(shape),)`` built
      from the Cartesian product of the per-axis slices.
    """

    def __init__(
        self, shape: Tuple, block_shape: Tuple, order: Tuple, block_order=None
    ):
        if block_order is not None:
            assert len(block_order) == len(shape)
        self.shape = tuple(shape)
        self.block_shape = tuple(np.min([shape, block_shape], axis=0))
        self.order = tuple(order)

        per_axis_slices = []
        for axis, dim in enumerate(self.shape):
            # The requested (unclamped) block size is what gets passed on.
            block_dim = block_shape[axis]
            axis_order = order[axis]
            flip_blocks = False
            if block_order is not None:
                flip_blocks = block_order[axis] == -1
            per_axis_slices.append(
                get_slices(dim, block_dim, axis_order, flip_blocks)
            )
        self.grid_slices = per_axis_slices
        self.grid_shape = tuple(len(axis_slices) for axis_slices in per_axis_slices)

        # Assumes C-style ordering.
        # The extra trailing axis of length len(shape) holds the slices
        # themselves.
        all_combinations = list(itertools.product(*self.grid_slices))
        target_shape = self.grid_shape + (len(shape),)
        self.slices = np.array(all_combinations, dtype=slice).reshape(target_shape)

    def index_iterator(self) -> Iterator[Tuple]:
        """Iterate over every grid index; empty when any dimension is 0."""
        if 0 in self.shape:
            return []
        axis_ranges = [range(extent) for extent in self.grid_shape]
        return itertools.product(*axis_ranges)
def idx2addr(index: tuple, shape: tuple):
    """Convert a multi-dimensional index into a flat row-major address.

    Parameters
    ----------
    index : tuple
        One coordinate per axis.
    shape : tuple
        Shape the index refers to.

    Returns
    -------
    int
        Sum of each coordinate multiplied by its stride, where the stride of
        an axis is the product of all later dimensions (the last stride is 1).
    """
    # np.prod replaces the np.product alias, which NumPy 2.0 removed.
    strides = [np.prod(shape[i:]) for i in range(1, len(shape))] + [1]
    addr: int = sum(np.array(index) * strides)
    return addr
def addr2idx(addr: int, shape: tuple):
    """Convert a flat row-major address into a multi-dimensional index.

    Parameters
    ----------
    addr : int
        Flat address. Assumed to be non-negative -- TODO confirm with callers.
    shape : tuple
        Shape the address refers to.

    Returns
    -------
    tuple of int
        One coordinate per stride, where the stride of an axis is the product
        of all later dimensions (the last stride is 1).
    """
    # np.prod replaces the np.product alias, which NumPy 2.0 removed.
    strides = [int(np.prod(shape[i:])) for i in range(1, len(shape))] + [1]
    index = []
    remainder = addr
    for stride in strides:
        # Integer divmod stays exact; float division (int(val / stride))
        # loses precision once the address exceeds 2**53.
        axis_index, remainder = divmod(remainder, stride)
        index.append(int(axis_index))
    return tuple(index)
def slice_sel_to_index_list(slice_selection: tuple):
    """Expand a selection of slices and indices into explicit index tuples.

    A slice contributes ``range(start, stop)`` (its step is not consulted); an
    integer index contributes itself. Entries of any other kind are skipped.
    The result is the Cartesian product of the per-entry choices.

    Returns
    -------
    list of tuple
    """
    per_entry_choices = []
    for entry in slice_selection:
        if isinstance(entry, slice):
            per_entry_choices.append(list(range(entry.start, entry.stop)))
        elif is_regular_subscript(entry):
            per_entry_choices.append([entry])
    return list(itertools.product(*per_entry_choices))
def translate_index_list(from_index_list, from_shape, to_shape):
    """Re-express indices of ``from_shape`` as indices of ``to_shape``.

    Each index is flattened to an address with ``idx2addr`` and expanded again
    with ``addr2idx`` for the target shape.

    Returns
    -------
    list of tuple
    """
    return [
        addr2idx(idx2addr(src_index, from_shape), to_shape)
        for src_index in from_index_list
    ]
def np_tensordot_param_test(as_, nda, bs, ndb, axes):
    """Report whether tensordot arguments have a shape mismatch.

    Error checking before everything gets passed into BlockArray operations.
    Modified from the original NumPy tensordot method for error checking:
    https://github.com/numpy/numpy/blob/v1.20.0/numpy/core/numeric.py#L949-L1139

    Parameters
    ----------
    as_, bs : sequence
        Shapes of the two operands.
    nda, ndb : int
        Numbers of dimensions of the two operands.
    axes : int or pair
        Either an integer N (last N axes of the first operand against the
        first N axes of the second) or a pair ``(axes_a, axes_b)`` whose
        members are integers or sequences of integers.

    Returns
    -------
    bool
        True when the axis lists differ in length or any paired dimensions
        differ; False when all paired dimensions match.
    """
    try:
        iter(axes)
    except Exception:
        axes_a = list(range(-axes, 0))
        axes_b = list(range(0, axes))
    else:
        axes_a, axes_b = axes

    try:
        na = len(axes_a)
        axes_a = list(axes_a)
    except TypeError:
        axes_a = [axes_a]
        na = 1

    try:
        nb = len(axes_b)
        axes_b = list(axes_b)
    except TypeError:
        axes_b = [axes_b]
        nb = 1

    if na != nb:
        return True
    for k in range(na):
        if as_[axes_a[k]] != bs[axes_b[k]]:
            return True
        # Negative axes are wrapped as in the NumPy routine this mirrors; the
        # wrapped values are local and do not influence the returned flag.
        if axes_a[k] < 0:
            axes_a[k] += nda
        if axes_b[k] < 0:
            axes_b[k] += ndb
    return False
# NumPy's internal axis-checking logic # https://www.kite.com/python/docs/numpy.core.multiarray.normalize_axis_index
def normalize_axis_index(axis, ndim):
    """
    Normalize a possibly negative axis index against ``ndim`` dimensions.

    Parameters
    ----------
    axis : int
        The un-normalized index of the axis. Can be negative
    ndim : int
        The number of dimensions of the array that `axis` should be normalized
        against

    Returns
    -------
    normalized_axis : int
        The normalized axis index, such that `0 <= normalized_axis < ndim`

    Raises
    ------
    AxisError
        If the axis index is invalid, when `-ndim <= axis < ndim` is false.

    Examples
    --------
    >>> normalize_axis_index(0, ndim=3)
    0
    >>> normalize_axis_index(1, ndim=3)
    1
    >>> normalize_axis_index(-1, ndim=3)
    2

    >>> normalize_axis_index(3, ndim=3)
    Traceback (most recent call last):
    ...
    AxisError: axis 3 is out of bounds for array of dimension 3
    >>> normalize_axis_index(-4, ndim=3)
    Traceback (most recent call last):
    ...
    AxisError: axis -4 is out of bounds for array of dimension 3
    """
    # The original chained test `-ndim > axis >= ndim` could never be true,
    # so out-of-range axes were silently wrapped by the modulo below.
    if not -ndim <= axis < ndim:
        raise AxisError(
            "axis {} is out of bounds for array of dimension {}".format(axis, ndim)
        )
    return axis % ndim