001// License: GPL. For details, see LICENSE file.
002package org.openstreetmap.josm.data.protobuf;
003
import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import org.openstreetmap.josm.tools.Logging;
013
014/**
015 * A basic Protobuf parser
016 *
017 * @author Taylor Smock
018 * @since 17862
019 */
020public class ProtobufParser implements AutoCloseable {
021    /**
022     * The default byte size (see {@link #VAR_INT_BYTE_SIZE} for var ints)
023     */
024    public static final byte BYTE_SIZE = 8;
025    /**
026     * The byte size for var ints (since the first byte is just an indicator for if the var int is done)
027     */
028    public static final byte VAR_INT_BYTE_SIZE = BYTE_SIZE - 1;
029    /**
030     * Used to get the most significant byte
031     */
032    static final byte MOST_SIGNIFICANT_BYTE = (byte) (1 << 7);
033    /**
034     * Convert a byte array to a number (little endian)
035     *
036     * @param bytes    The bytes to convert
037     * @param byteSize The size of the byte. For var ints, this is 7, for other ints, this is 8.
038     * @return An appropriate {@link Number} class.
039     */
040    public static Number convertByteArray(byte[] bytes, byte byteSize) {
041        long number = 0;
042        for (int i = 0; i < bytes.length; i++) {
043            // Need to convert to uint64 in order to avoid bit operation from filling in 1's and overflow issues
044            number += Byte.toUnsignedLong(bytes[i]) << (byteSize * i);
045        }
046        return convertLong(number);
047    }
048
049    /**
050     * Convert a long to an appropriate {@link Number} class
051     *
052     * @param number The long to convert
053     * @return A {@link Number}
054     */
055    public static Number convertLong(long number) {
056        // TODO deal with booleans
057        if (number <= Byte.MAX_VALUE && number >= Byte.MIN_VALUE) {
058            return (byte) number;
059        } else if (number <= Short.MAX_VALUE && number >= Short.MIN_VALUE) {
060            return (short) number;
061        } else if (number <= Integer.MAX_VALUE && number >= Integer.MIN_VALUE) {
062            return (int) number;
063        }
064        return number;
065    }
066
067    /**
068     * Decode a zig-zag encoded value
069     *
070     * @param signed The value to decode
071     * @return The decoded value
072     */
073    public static Number decodeZigZag(Number signed) {
074        final long value = signed.longValue();
075        return convertLong((value >> 1) ^ -(value & 1));
076    }
077
078    /**
079     * Encode a number to a zig-zag encode value
080     *
081     * @param signed The number to encode
082     * @return The encoded value
083     */
084    public static Number encodeZigZag(Number signed) {
085        final long value = signed.longValue();
086        // This boundary condition could be >= or <= or both. Tests indicate that it doesn't actually matter.
087        // The only difference would be the number type returned, except it is always converted to the most basic type.
088        final int shift = (value > Integer.MAX_VALUE || value < Integer.MIN_VALUE ? Long.BYTES : Integer.BYTES) * 8 - 1;
089        return convertLong((value << 1) ^ (value >> shift));
090    }
091
092    private final InputStream inputStream;
093
094    /**
095     * Create a new parser
096     *
097     * @param bytes The bytes to parse
098     */
099    public ProtobufParser(byte[] bytes) {
100        this(new ByteArrayInputStream(bytes));
101    }
102
103    /**
104     * Create a new parser
105     *
106     * @param inputStream The InputStream (will be fully read at this time)
107     */
108    public ProtobufParser(InputStream inputStream) {
109        if (inputStream.markSupported()) {
110            this.inputStream = inputStream;
111        } else {
112            this.inputStream = new BufferedInputStream(inputStream);
113        }
114    }
115
116    /**
117     * Read all records
118     *
119     * @return A collection of all records
120     * @throws IOException - if an IO error occurs
121     */
122    public Collection<ProtobufRecord> allRecords() throws IOException {
123        Collection<ProtobufRecord> records = new ArrayList<>();
124        while (this.hasNext()) {
125            records.add(new ProtobufRecord(this));
126        }
127        return records;
128    }
129
130    @Override
131    public void close() {
132        try {
133            this.inputStream.close();
134        } catch (IOException e) {
135            Logging.error(e);
136        }
137    }
138
139    /**
140     * Check if there is more data to read
141     *
142     * @return {@code true} if there is more data to read
143     * @throws IOException - if an IO error occurs
144     */
145    public boolean hasNext() throws IOException {
146        return this.inputStream.available() > 0;
147    }
148
149    /**
150     * Get the "next" WireType
151     *
152     * @return {@link WireType} expected
153     * @throws IOException - if an IO error occurs
154     */
155    public WireType next() throws IOException {
156        this.inputStream.mark(16);
157        try {
158            return WireType.values()[this.inputStream.read() << 3];
159        } finally {
160            this.inputStream.reset();
161        }
162    }
163
164    /**
165     * Get the next byte
166     *
167     * @return The next byte
168     * @throws IOException - if an IO error occurs
169     */
170    public int nextByte() throws IOException {
171        return this.inputStream.read();
172    }
173
174    /**
175     * Get the next 32 bits ({@link WireType#THIRTY_TWO_BIT})
176     *
177     * @return a byte array of the next 32 bits (4 bytes)
178     * @throws IOException - if an IO error occurs
179     */
180    public byte[] nextFixed32() throws IOException {
181        // 4 bytes == 32 bits
182        return readNextBytes(4);
183    }
184
185    /**
186     * Get the next 64 bits ({@link WireType#SIXTY_FOUR_BIT})
187     *
188     * @return a byte array of the next 64 bits (8 bytes)
189     * @throws IOException - if an IO error occurs
190     */
191    public byte[] nextFixed64() throws IOException {
192        // 8 bytes == 64 bits
193        return readNextBytes(8);
194    }
195
196    /**
197     * Get the next delimited message ({@link WireType#LENGTH_DELIMITED})
198     *
199     * @return The next length delimited message
200     * @throws IOException - if an IO error occurs
201     */
202    public byte[] nextLengthDelimited() throws IOException {
203        int length = convertByteArray(this.nextVarInt(), VAR_INT_BYTE_SIZE).intValue();
204        return readNextBytes(length);
205    }
206
207    /**
208     * Get the next var int ({@code WireType#VARINT})
209     *
210     * @return The next var int ({@code int32}, {@code int64}, {@code uint32}, {@code uint64}, {@code bool}, {@code enum})
211     * @throws IOException - if an IO error occurs
212     */
213    public byte[] nextVarInt() throws IOException {
214        List<Byte> byteList = new ArrayList<>();
215        int currentByte = this.nextByte();
216        while ((byte) (currentByte & MOST_SIGNIFICANT_BYTE) == MOST_SIGNIFICANT_BYTE && currentByte > 0) {
217            // Get rid of the leading bit (shift left 1, then shift right 1 unsigned)
218            byteList.add((byte) (currentByte ^ MOST_SIGNIFICANT_BYTE));
219            currentByte = this.nextByte();
220        }
221        // The last byte doesn't drop the most significant bit
222        byteList.add((byte) currentByte);
223        byte[] byteArray = new byte[byteList.size()];
224        for (int i = 0; i < byteList.size(); i++) {
225            byteArray[i] = byteList.get(i);
226        }
227
228        return byteArray;
229    }
230
231    /**
232     * Read an arbitrary number of bytes
233     *
234     * @param size The number of bytes to read
235     * @return a byte array of the specified size, filled with bytes read (unsigned)
236     * @throws IOException - if an IO error occurs
237     */
238    private byte[] readNextBytes(int size) throws IOException {
239        byte[] bytesRead = new byte[size];
240        for (int i = 0; i < bytesRead.length; i++) {
241            bytesRead[i] = (byte) this.nextByte();
242        }
243        return bytesRead;
244    }
245}