001/*-
002 * Copyright 2016 Diamond Light Source Ltd.
003 *
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 */
009
010package org.eclipse.january.dataset;
011
012import java.util.ArrayList;
013import java.util.Arrays;
014import java.util.List;
015
016public final class BroadcastUtils {
017
018        /**
019         * Calculate shapes for broadcasting
020         * @param oldShape old shape
021         * @param size dataset size
022         * @param newShape new shape
023         * @return broadcasted shape and full new shape or null if it cannot be done
024         */
025        public static int[][] calculateBroadcastShapes(int[] oldShape, int size, int... newShape) {
026                if (newShape == null) {
027                        return null;
028                }
029        
030                int brank = newShape.length;
031                if (brank == 0) {
032                        if (size == 1) {
033                                return new int[][] {oldShape, newShape};
034                        }
035                        return null;
036                }
037        
038                if (Arrays.equals(oldShape, newShape)) {
039                        return new int[][] {oldShape, newShape};
040                }
041
042                if (ShapeUtils.calcSize(oldShape) != size) {
043                        throw new IllegalArgumentException("Size must match old shape");
044                }
045
046                int offset = brank - oldShape.length;
047                if (offset < 0) { // when new shape is incomplete
048                        newShape = padShape(newShape, -offset);
049                        offset = 0;
050                }
051
052                int[] bshape = padShape(oldShape, offset); // new shape has extra dimensions
053
054                for (int i = 0; i < brank; i++) {
055                        if (newShape[i] != bshape[i] && bshape[i] != 1 && newShape[i] != 1) {
056                                return null;
057                        }
058                }
059        
060                return new int[][] {bshape, newShape};
061        }
062
063        /**
064         * Pad shape by prefixing with ones
065         * @param shape to pad
066         * @param padding number of dimensions to add
067         * @return new shape or old shape if padding is zero
068         */
069        public static int[] padShape(final int[] shape, final int padding) {
070                if (padding < 0) {
071                        throw new IllegalArgumentException("Padding must be zero or greater");
072                }
073        
074                if (padding == 0) {
075                        return shape;
076                }
077
078                final int[] nshape = new int[shape.length + padding];
079                Arrays.fill(nshape, 1);
080                System.arraycopy(shape, 0, nshape, padding, shape.length);
081                return nshape;
082        }
083
084        /**
085         * Take in shapes and broadcast them to same rank
086         * @param shapes null shapes are ignored and passed through
087         * @return list of broadcasted shapes plus the first entry is the maximum shape
088         */
089        public static List<int[]> broadcastShapes(int[]... shapes) {
090                int maxRank = -1;
091                for (int[] s : shapes) {
092                        if (s == null) {
093                                continue;
094                        }
095        
096                        int r = s.length;
097                        if (r > maxRank) {
098                                maxRank = r;
099                        }
100                }
101        
102                List<int[]> newShapes = new ArrayList<int[]>();
103                if (maxRank < 0) {
104                        for (int i = 0; i <= shapes.length; i++) { // note the extra null
105                                newShapes.add(null);
106                        }
107                        return newShapes;
108                }
109
110                for (int[] s : shapes) {
111                        newShapes.add(s == null ? null : padShape(s, maxRank - s.length));
112                }
113
114                int[] maxShape = new int[maxRank];
115                for (int i = 0; i < maxRank; i++) {
116                        int m = -1;
117                        for (int[] s : newShapes) {
118                                if (s == null) {
119                                        continue;
120                                }
121                                int l = s[i];
122                                if (l > m) {
123                                        if (m > 1) {
124                                                throw new IllegalArgumentException("A shape's dimension was not one or equal to maximum");
125                                        }
126                                        m = l;
127                                }
128                        }
129                        maxShape[i] = m;
130                }
131
132                checkShapes(maxShape, newShapes);
133                newShapes.add(0, maxShape);
134                return newShapes;
135        }
136
137        /**
138         * Take in shapes and broadcast them to maximum shape
139         * @param maxShape maximum shape
140         * @param shapes inputs
141         * @return list of broadcasted shapes
142         */
143        public static List<int[]> broadcastShapesToMax(int[] maxShape, int[]... shapes) {
144                int maxRank = maxShape == null ? -1 : maxShape.length;
145                for (int[] s : shapes) {
146                        if (s == null) {
147                                continue;
148                        }
149        
150                        int r = s.length;
151                        if (r > maxRank) {
152                                throw new IllegalArgumentException("A shape exceeds given rank of maximum shape");
153                        }
154                }
155        
156                List<int[]> newShapes = new ArrayList<int[]>();
157                for (int[] s : shapes) {
158                        newShapes.add(s == null ? null : padShape(s, maxRank - s.length));
159                }
160
161                if (maxShape != null ) {
162                        checkShapes(maxShape, newShapes);
163                }
164                return newShapes;
165        }
166
167        private static void checkShapes(int[] maxShape, List<int[]> newShapes) {
168                for (int i = 0; i < maxShape.length; i++) {
169                        int m = maxShape[i];
170                        for (int[] s : newShapes) {
171                                if (s == null) {
172                                        continue;
173                                }
174                                int l = s[i];
175                                if (l != 1 && l != m) {
176                                        throw new IllegalArgumentException("A shape's dimension was not one or equal to maximum");
177                                }
178                        }
179                }
180        }
181
182        static Dataset createDataset(final Dataset a, final Dataset b, final int[] shape) {
183                final Class<? extends Dataset> rc;
184                final int ar = a.getRank();
185                final int br = b.getRank();
186                Class<? extends Dataset> tc = InterfaceUtils.getBestInterface(a.getClass(), b.getClass());
187                if (ar == 0 ^ br == 0) { // ignore type of zero-rank dataset unless it's floating point
188                        if (ar == 0) {
189                                rc = a.hasFloatingPointElements() ? tc : b.getClass();
190                        } else {
191                                rc = b.hasFloatingPointElements() ? tc : a.getClass();
192                        }
193                } else {
194                        rc = tc;
195                }
196                final int ia = a.getElementsPerItem();
197                final int ib = b.getElementsPerItem();
198        
199                return DatasetFactory.zeros(ia > ib ? ia : ib, rc, shape);
200        }
201
202        /**
203         * Check if dataset item sizes are compatible
204         * <p>
205         * Dataset a is considered compatible with the output dataset if any of the
206         * conditions are true:
207         * <ul>
208         * <li>o is undefined</li>
209         * <li>a has item size equal to o's</li>
210         * <li>a has item size equal to 1</li>
211         * <li>o has item size equal to 1</li>
212         * </ul>
213         * @param a input dataset a
214         * @param o output dataset (can be null)
215         */
216        static void checkItemSize(Dataset a, Dataset o) {
217                final int isa = a.getElementsPerItem();
218                if (o != null) {
219                        final int iso = o.getElementsPerItem();
220                        if (isa != iso && isa != 1 && iso != 1) {
221                                throw new IllegalArgumentException("Can not output to dataset whose number of elements per item mismatch inputs'");
222                        }
223                }
224        }
225
226        /**
227         * Check if dataset item sizes are compatible
228         * <p>
229         * Dataset a is considered compatible with the output dataset if any of the
230         * conditions are true:
231         * <ul>
232         * <li>a has item size equal to b's</li>
233         * <li>a has item size equal to 1</li>
234         * <li>b has item size equal to 1</li>
235         * <li>a or b are single-valued</li>
236         * </ul>
237         * and, o is undefined, or any of the following are true:
238         * <ul>
239         * <li>o has item size equal to maximum of a and b's</li>
240         * <li>o has item size equal to 1</li>
241         * <li>a and b have item sizes of 1</li>
242         * </ul>
243         * @param a input dataset a
244         * @param b input dataset b
245         * @param o output dataset
246         */
247        static void checkItemSize(Dataset a, Dataset b, Dataset o) {
248                final int isa = a.getElementsPerItem();
249                final int isb = b.getElementsPerItem();
250                if (isa != isb && isa != 1 && isb != 1) {
251                        // exempt single-value dataset case too
252                        if ((isa == 1 || b.getSize() != 1) && (isb == 1 || a.getSize() != 1) ) {
253                                throw new IllegalArgumentException("Can not broadcast where number of elements per item mismatch and one does not equal another");
254                        }
255                }
256                if (o != null && BooleanDataset.class.isAssignableFrom(o.getClass())) {
257                        final int ism = Math.max(isa, isb);
258                        final int iso = o.getElementsPerItem();
259                        if (iso != ism && iso != 1 && ism != 1) {
260                                throw new IllegalArgumentException("Can not output to dataset whose number of elements per item mismatch inputs'");
261                        }
262                }
263        }
264
265        /**
266         * Create a stride array from a dataset to a broadcast shape
267         * @param a dataset
268         * @param broadcastShape shape to broadcast
269         * @return stride array
270         */
271        public static int[] createBroadcastStrides(Dataset a, final int[] broadcastShape) {
272                return createBroadcastStrides(a.getElementsPerItem(), a.getShapeRef(), a.getStrides(), broadcastShape);
273        }
274
275        /**
276         * Create a stride array from a dataset to a broadcast shape
277         * @param isize item size
278         * @param oShape original shape
279         * @param oStride original stride
280         * @param broadcastShape shape to broadcast
281         * @return stride array
282         */
283        public static int[] createBroadcastStrides(final int isize, final int[] oShape, final int[] oStride, final int[] broadcastShape) {
284                if (oShape == null) {
285                        if (broadcastShape == null) {
286                                return null;
287                        }
288                        throw new IllegalArgumentException("Broadcast shape must be null if original shape is null");
289                }
290                int rank = oShape.length;
291                if (broadcastShape.length != rank) {
292                        throw new IllegalArgumentException("Dataset must have same rank as broadcast shape");
293                }
294        
295                int[] stride = new int[rank];
296                if (oStride == null) {
297                        int s = isize;
298                        for (int j = rank - 1; j >= 0; j--) {
299                                if (broadcastShape[j] == oShape[j]) {
300                                        stride[j] = s;
301                                        s *= oShape[j];
302                                } else {
303                                        stride[j] = 0;
304                                }
305                        }
306                } else {
307                        for (int j = 0; j < rank; j++) {
308                                if (broadcastShape[j] == oShape[j]) {
309                                        stride[j] = oStride[j];
310                                } else {
311                                        stride[j] = 0;
312                                }
313                        }
314                }
315        
316                return stride;
317        }
318
319        /**
320         * Converts and broadcast all objects as datasets of same shape
321         * @param objects to convert and broadcast
322         * @return all as broadcasted to same shape
323         */
324        public static Dataset[] convertAndBroadcast(Object... objects) {
325                final int n = objects.length;
326
327                Dataset[] datasets = new Dataset[n];
328                int[][] shapes = new int[n][];
329                for (int i = 0; i < n; i++) {
330                        Dataset d = DatasetFactory.createFromObject(objects[i]);
331                        datasets[i] = d;
332                        shapes[i] = d.getShapeRef();
333                }
334
335                List<int[]> nShapes = broadcastShapes(shapes);
336                int[] mshape = nShapes.get(0);
337                for (int i = 0; i < n; i++) {
338                        datasets[i] = datasets[i].getBroadcastView(mshape);
339                }
340
341                return datasets;
342        }
343}