/*
 * Copyright (c) 2013-2014, RIKEN, Japan
 * All rights reserved.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 3
 * of the License, or (at your option) any later version.
 * 
 * This program is a derived work of GROMACS xdrfile project version 1.1.1.
 * http://www.gromacs.org/Downloads
 * 
 * The original code is written in C and licensed as below.
 */

/* -*- mode: c; tab-width: 4; indent-tabs-mode: t; c-basic-offset: 4 -*- 
 *
 * $Id$
 *
 * Copyright (c) Erik Lindahl, David van der Spoel 2003,2004.
 * Coordinate compression (c) by Frans van Hoesel. 
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 3
 * of the License, or (at your option) any later version.
 */
package jp.riken.lib.xdrfile;

import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/**
 * Reader for big-endian XDR data and GROMACS compressed coordinate blocks.
 *
 * <p>This is a Java port of the reading side of the GROMACS xdrfile library
 * (version 1.1.1). Open a file with {@link #openRead(File)}, read primitives
 * with the {@code read*} methods or whole coordinate blocks with
 * {@link #decompressCoordFloat(XdrfileFrame)}, and finally call
 * {@link #close()}.
 *
 * <p>An instance keeps mutable scratch buffers and performs no
 * synchronization, so it must not be shared between threads without
 * external locking.
 */
public class Xdrfile {

    /** Stream opened by {@link #openRead(File)}; {@code null} until then. */
    private BufferedInputStream inputStream;
    /** Big-endian scratch buffer used to decode single primitives. */
    private ByteBuffer buffer;
    private static final int BUFFER_SIZE = 8192;
    /** Decoded integer coordinates of the current block (3 ints per atom). */
    private int[] buf1;
    /** Capacity of {@link #buf1}, in ints. */
    private int buf1size;
    /** Raw compressed bit stream of the current block. */
    private byte[] buf2;
    /** Capacity of {@link #buf2}, in bytes. */
    private int buf2size;
    /**
     * Bit-decoder state, mirroring the header ints of the C implementation:
     * [0] index of the next unread byte in {@link #buf2}, [1] number of
     * buffered bits not yet consumed, [2] the buffered bits themselves.
     */
    private long[] buf3;

    public Xdrfile() {
        buffer = ByteBuffer.allocate(BUFFER_SIZE);
        buffer.order(ByteOrder.BIG_ENDIAN);
        buf3 = new long[3];
    }

    /**
     * Opens {@code file} for reading. Any stream previously opened by this
     * instance is closed first so it is not leaked.
     *
     * @param file the file to read
     * @throws XdrfileException CANNOT_OPEN if the file cannot be opened, or
     *         IO_ERROR if closing the previously opened stream fails
     */
    public void openRead(File file) throws XdrfileException {
        close();
        try {
            inputStream = new BufferedInputStream(new FileInputStream(file));
        } catch (FileNotFoundException e) {
            throw new XdrfileException(XdrfileException.CANNOT_OPEN, e);
        }
    }

    /**
     * Closes the underlying stream, if one was opened. Closing an already
     * closed stream has no effect.
     *
     * @throws XdrfileException IO_ERROR if closing fails
     */
    public void close() throws XdrfileException {
        if (inputStream != null) {
            try {
                inputStream.close();
            } catch (IOException e) {
                throw new XdrfileException(XdrfileException.IO_ERROR, e);
            }
        }
    }

    /**
     * Reads exactly {@code len} bytes into {@code dst} starting at
     * {@code off}. A single {@code InputStream.read} call may legitimately
     * return fewer bytes than requested before the end of the stream, so
     * this keeps reading until the request is satisfied.
     *
     * @throws XdrfileException EOF if the stream ends first, IO_ERROR if
     *         the underlying read fails
     */
    private void readFully(byte[] dst, int off, int len) throws XdrfileException {
        try {
            while (len > 0) {
                int num = inputStream.read(dst, off, len);
                if (num < 0) {
                    throw new XdrfileException(XdrfileException.EOF);
                }
                off += num;
                len -= num;
            }
        } catch (IOException e) {
            throw new XdrfileException(XdrfileException.IO_ERROR, e);
        }
    }

    /**
     * Reads a single byte.
     *
     * @return the byte read
     * @throws XdrfileException EOF at end of stream, IO_ERROR on read failure
     */
    public byte readByte() throws XdrfileException {
        readFully(buffer.array(), 0, 1);
        return buffer.get(0);
    }

    /**
     * Reads {@code n} bytes into {@code data[0..n-1]}.
     *
     * @throws XdrfileException EOF at end of stream, IO_ERROR on read failure
     */
    public void readByte(byte[] data, int n) throws XdrfileException {
        readFully(data, 0, n);
    }

    /**
     * Reads one big-endian 32-bit integer.
     *
     * @throws XdrfileException EOF at end of stream, IO_ERROR on read failure
     */
    public int readInt() throws XdrfileException {
        readFully(buffer.array(), 0, 4);
        return buffer.getInt(0);
    }

    /**
     * Reads {@code n} big-endian 32-bit integers into {@code data[0..n-1]}.
     *
     * @throws XdrfileException EOF at end of stream, IO_ERROR on read failure
     */
    public void readInt(int[] data, int n) throws XdrfileException {
        for (int i = 0; i < n; i++) {
            data[i] = readInt();
        }
    }

    /**
     * Reads one big-endian 32-bit IEEE float.
     *
     * @throws XdrfileException EOF at end of stream, IO_ERROR on read failure
     */
    public float readFloat() throws XdrfileException {
        readFully(buffer.array(), 0, 4);
        return buffer.getFloat(0);
    }

    /**
     * Reads {@code n} big-endian 32-bit floats into {@code data[0..n-1]}.
     *
     * @throws XdrfileException EOF at end of stream, IO_ERROR on read failure
     */
    public void readFloat(float[] data, int n) throws XdrfileException {
        for (int i = 0; i < n; i++) {
            data[i] = readFloat();
        }
    }

    /**
     * Per-axis ranges for the "small delta" encoding, indexed by the small
     * index read from the stream; this mirrors the magicints table of the
     * C xdrfile library. Entries below {@link #FIRSTIDX} are unused (zero).
     */
    private static final int[] MAGICINTS = {
        0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 10, 12, 16, 20, 25, 32, 40, 50, 64,
        80, 101, 128, 161, 203, 256, 322, 406, 512, 645, 812, 1024, 1290,
        1625, 2048, 2580, 3250, 4096, 5060, 6501, 8192, 10321, 13003,
        16384, 20642, 26007, 32768, 41285, 52015, 65536, 82570, 104031,
        131072, 165140, 208063, 262144, 330280, 416127, 524287, 660561,
        832255, 1048576, 1321122, 1664510, 2097152, 2642245, 3329021,
        4194304, 5284491, 6658042, 8388607, 10568983, 13316085, 16777216
    };
    /** First index of {@link #MAGICINTS} that holds a non-zero range. */
    private static final int FIRSTIDX = 9;

    /**
     * Reads one compressed coordinate block into {@code frame}.
     *
     * <p>The block starts with its atom count. Blocks of nine atoms or fewer
     * are stored as plain floats. Larger blocks store a precision, the
     * integer bounding box, the initial small-delta index and an opaque bit
     * stream, which is decoded here; for those blocks the precision is
     * stored with {@code frame.setPrec}. Coordinates are written to
     * {@code frame.getX()} as consecutive x, y, z triples.
     *
     * @param frame destination frame; must have room for the stored atoms
     * @throws XdrfileException INCONSISTENT_DATA if the stored atom count is
     *         negative or exceeds {@code frame.getNatoms()}, or the
     *         compressed data is malformed; EOF or IO_ERROR if reading fails
     */
    public void decompressCoordFloat(XdrfileFrame frame)
            throws XdrfileException {
        int size = readInt();
        if (size < 0 || frame.getNatoms() < size) {
            throw new XdrfileException(XdrfileException.INCONSISTENT_DATA);
        }
        int size3 = size * 3;
        if (size3 > buf1size) {
            buf1size = size3;
            buf1 = new int[buf1size];
            buf2size = (int) (size3 * 1.2 * 4);
            buf2 = new byte[buf2size];
        }
        // Small blocks are written uncompressed. The decision depends on the
        // atom count stored in the stream, not on the frame's capacity.
        if (size <= 9) {
            readFloat(frame.getX(), size3);
            return;
        }

        frame.setPrec(readFloat());
        int[] minint = new int[3];
        int[] maxint = new int[3];
        readInt(minint, 3);
        readInt(maxint, 3);

        long[] sizeint = new long[3];
        for (int d = 0; d < 3; d++) {
            sizeint[d] = (long) maxint[d] - (long) minint[d] + 1L;
            if (sizeint[d] <= 0) {
                // max < min: the bounding box is corrupt.
                throw new XdrfileException(XdrfileException.INCONSISTENT_DATA);
            }
        }

        int[] bitsizeint = new int[3];
        int bitsize;
        if ((sizeint[0] | sizeint[1] | sizeint[2]) > 0xffffffL) {
            // Some axis needs more than 24 bits: encode each axis separately.
            for (int d = 0; d < 3; d++) {
                bitsizeint[d] = sizeOfInt(sizeint[d]);
            }
            bitsize = 0; // flag the use of large sizes
        } else {
            bitsize = sizeOfInts(3, sizeint);
        }

        int smallidx = readInt();
        checkSmallIdx(smallidx);
        int smaller = MAGICINTS[Math.max(FIRSTIDX, smallidx - 1)] / 2;
        int smallnum = MAGICINTS[smallidx] / 2;
        long[] sizesmall = new long[3];
        sizesmall[0] = sizesmall[1] = sizesmall[2] = MAGICINTS[smallidx];

        // Length in bytes of the compressed bit stream, followed by the stream.
        int byteCount = readInt();
        readOpaque(byteCount);
        buf3[0] = buf3[1] = buf3[2] = 0;

        float[] x = frame.getX();
        float invPrec = 1 / frame.getPrec();
        int[] prev = new int[3];
        int idx = 0;
        int run = 0;
        int i = 0;
        while (i < size) {
            int ix = i * 3;
            int iy = ix + 1;
            int iz = ix + 2;
            // Full-size coordinate, relative to the bounding-box minimum.
            if (bitsize == 0) {
                buf1[ix] = decodeBits(bitsizeint[0]);
                buf1[iy] = decodeBits(bitsizeint[1]);
                buf1[iz] = decodeBits(bitsizeint[2]);
            } else {
                decodeInts(3, bitsize, sizeint, ix);
            }

            i++;
            buf1[ix] += minint[0];
            buf1[iy] += minint[1];
            buf1[iz] += minint[2];

            prev[0] = buf1[ix];
            prev[1] = buf1[iy];
            prev[2] = buf1[iz];

            // A set flag bit announces a new run length (a multiple of 3
            // ints) combined with a -1/0/+1 change of the small index;
            // otherwise the previous run length is reused.
            int isSmaller = 0;
            if (decodeBits(1) == 1) {
                run = decodeBits(5);
                isSmaller = run % 3;
                run -= isSmaller;
                isSmaller--;
            }
            if (run > 0) {
                // Each run entry is one more atom; it must not overrun size.
                if (i + run / 3 > size) {
                    throw new XdrfileException(XdrfileException.INCONSISTENT_DATA);
                }
                ix += 3;
                iy += 3;
                iz += 3;
                for (int k = 0; k < run; k += 3) {
                    decodeInts(3, smallidx, sizesmall, ix);
                    i++;
                    buf1[ix] += prev[0] - smallnum;
                    buf1[iy] += prev[1] - smallnum;
                    buf1[iz] += prev[2] - smallnum;
                    if (k == 0) {
                        // interchange first with second atom for better
                        // compression of water molecules
                        int swap = buf1[ix];
                        buf1[ix] = prev[0];
                        prev[0] = swap;
                        swap = buf1[iy];
                        buf1[iy] = prev[1];
                        prev[1] = swap;
                        swap = buf1[iz];
                        buf1[iz] = prev[2];
                        prev[2] = swap;
                        x[idx++] = prev[0] * invPrec;
                        x[idx++] = prev[1] * invPrec;
                        x[idx++] = prev[2] * invPrec;
                    } else {
                        prev[0] = buf1[ix];
                        prev[1] = buf1[iy];
                        prev[2] = buf1[iz];
                    }
                    x[idx++] = buf1[ix] * invPrec;
                    x[idx++] = buf1[iy] * invPrec;
                    x[idx++] = buf1[iz] * invPrec;
                }
            } else {
                x[idx++] = buf1[ix] * invPrec;
                x[idx++] = buf1[iy] * invPrec;
                x[idx++] = buf1[iz] * invPrec;
            }

            smallidx += isSmaller;
            checkSmallIdx(smallidx);
            if (isSmaller < 0) {
                smallnum = smaller;
                smaller = (smallidx > FIRSTIDX) ? MAGICINTS[smallidx - 1] / 2 : 0;
            } else if (isSmaller > 0) {
                smaller = smallnum;
                smallnum = MAGICINTS[smallidx] / 2;
            }
            sizesmall[0] = sizesmall[1] = sizesmall[2] = MAGICINTS[smallidx];
        }
    }

    /**
     * Rejects a small index outside the usable part of {@link #MAGICINTS};
     * such an index would yield a zero range (division by zero) or an
     * out-of-bounds table access.
     */
    private static void checkSmallIdx(int smallidx) throws XdrfileException {
        if (smallidx < FIRSTIDX || smallidx >= MAGICINTS.length) {
            throw new XdrfileException(XdrfileException.INCONSISTENT_DATA);
        }
    }

    /**
     * Returns the number of bits needed to represent values in
     * {@code [0, value)}, capped at 32.
     */
    private static int sizeOfInt(long value) {
        long num = 1;
        int numBits = 0;
        while (value >= num && numBits < 32) {
            numBits++;
            num <<= 1;
        }
        return numBits;
    }

    /**
     * Returns the number of bits needed to store the product of the first
     * {@code numInts} entries of {@code values}. The product is accumulated
     * as a little-endian array of base-256 digits so it cannot overflow.
     */
    private static int sizeOfInts(int numInts, long[] values) {
        int[] bytes = new int[32];
        bytes[0] = 1;
        int numBytes = 1;
        for (int i = 0; i < numInts; i++) {
            long carry = 0;
            int bytecnt;
            for (bytecnt = 0; bytecnt < numBytes; bytecnt++) {
                carry = bytes[bytecnt] * values[i] + carry;
                bytes[bytecnt] = (int) (carry & 0xffL);
                carry >>= 8;
            }
            while (carry != 0) {
                bytes[bytecnt++] = (int) (carry & 0xffL);
                carry >>= 8;
            }
            numBytes = bytecnt;
        }
        // Bits used by the most significant digit, plus 8 per lower digit.
        int numBits = 0;
        int num = 1;
        numBytes--;
        while (bytes[numBytes] >= num) {
            numBits++;
            num *= 2;
        }
        return numBits + numBytes * 8;
    }

    /**
     * Reads {@code cnt} bytes of XDR opaque data into {@link #buf2} and
     * skips the padding that rounds the data up to a multiple of 4 bytes.
     *
     * @throws XdrfileException INCONSISTENT_DATA if {@code cnt} is negative
     *         or does not fit {@link #buf2}; EOF or IO_ERROR if reading fails
     */
    private void readOpaque(int cnt) throws XdrfileException {
        if (cnt == 0) {
            return;
        }
        if (cnt < 0 || buf2 == null || cnt > buf2.length) {
            throw new XdrfileException(XdrfileException.INCONSISTENT_DATA);
        }
        readFully(buf2, 0, cnt);
        int remainder = cnt % 4;
        if (remainder > 0) {
            readFully(buffer.array(), 0, 4 - remainder);
        }
    }

    /**
     * Extracts the next {@code numBits} bits (0..32) from the compressed
     * stream in {@link #buf2}, most significant bit first, advancing the
     * decoder state held in {@link #buf3}.
     */
    private int decodeBits(int numBits) {
        int cnt = (int) buf3[0];
        long lastbits = buf3[1];
        long lastbyte = buf3[2];
        // Java masks shift distances to 5 bits, so 1 << 32 == 1; special-case 32.
        int mask = (numBits >= 32) ? -1 : (1 << numBits) - 1;
        int num = 0;
        while (numBits >= 8) {
            // Keep the window at 32 bits, like the unsigned int of the C code.
            lastbyte = ((lastbyte << 8) | asUnsignedByte(cnt++)) & 0xffffffffL;
            num |= (int) ((lastbyte >> lastbits) << (numBits - 8));
            numBits -= 8;
        }
        if (numBits > 0) {
            if (lastbits < numBits) {
                lastbits += 8;
                lastbyte = ((lastbyte << 8) | asUnsignedByte(cnt++)) & 0xffffffffL;
            }
            lastbits -= numBits;
            num |= (int) ((lastbyte >> lastbits) & ((1 << numBits) - 1));
        }
        num &= mask;
        buf3[0] = cnt;
        buf3[1] = lastbits;
        buf3[2] = lastbyte;
        return num;
    }

    /** Returns {@code buf2[cnt]} interpreted as an unsigned value 0..255. */
    private int asUnsignedByte(int cnt) {
        return buf2[cnt] & 0xff;
    }

    /**
     * Decodes {@code numInts} integers that were packed as a single
     * mixed-radix number of {@code numBits} bits, where integer {@code i}
     * lies in {@code [0, sizes[i])}. The results are stored in
     * {@code buf1[offset .. offset + numInts - 1]}.
     */
    private void decodeInts(int numInts, int numBits, long[] sizes, int offset) {
        int[] bytes = new int[32];
        int numBytes = 0;
        while (numBits > 8) {
            bytes[numBytes++] = decodeBits(8);
            numBits -= 8;
        }
        if (numBits > 0) {
            bytes[numBytes++] = decodeBits(numBits);
        }
        for (int i = numInts - 1; i > 0; i--) {
            // Long division of the byte array by sizes[i]. The partial
            // remainder can reach sizes[i] * 256 (up to 2^32), which
            // overflows int, so the arithmetic is done in long (the C
            // original relies on unsigned conversion here).
            long num = 0;
            for (int j = numBytes - 1; j >= 0; j--) {
                num = (num << 8) | bytes[j];
                long p = num / sizes[i];
                bytes[j] = (int) p;
                num -= p * sizes[i];
            }
            buf1[offset + i] = (int) num;
        }
        buf1[offset] = bytes[0] | (bytes[1] << 8) | (bytes[2] << 16)
                | (bytes[3] << 24);
    }
}
