/*
 * Decompiled with CFR 0.152.
 */
package ghidra.features.bsim.query.file;

import generic.lsh.vector.LSHVector;
import ghidra.features.bsim.query.client.tables.CachedStatement;
import ghidra.features.bsim.query.client.tables.SQLComplexTable;
import ghidra.features.bsim.query.description.VectorResult;
import ghidra.features.bsim.query.elastic.Base64Lite;
import ghidra.features.bsim.query.elastic.Base64VectorFactory;
import ghidra.features.bsim.query.file.VectorStore;
import ghidra.features.bsim.query.file.VectorStoreEntry;
import java.io.IOException;
import java.io.StringReader;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.HashMap;
import java.util.Map;

public class H2VectorTable
extends SQLComplexTable {
    public static final String TABLE_NAME = "h2_vectable";
    private final Base64VectorFactory vectorFactory;
    private final VectorStore vectorStore;
    private final CachedStatement<PreparedStatement> insert_stmt = new CachedStatement();
    private final CachedStatement<PreparedStatement> select_by_rowid_stmt = new CachedStatement();
    private final CachedStatement<PreparedStatement> select_id_by_hash_stmt = new CachedStatement();
    private final CachedStatement<PreparedStatement> update_by_hash_stmt = new CachedStatement();
    private final CachedStatement<PreparedStatement> select_count_by_rowid_stmt = new CachedStatement();
    private final CachedStatement<PreparedStatement> update_by_rowid_stmt = new CachedStatement();

    public H2VectorTable(Base64VectorFactory vectorFactory, VectorStore vectorStore) {
        super(TABLE_NAME, "id");
        this.vectorFactory = vectorFactory;
        this.vectorStore = vectorStore;
    }

    @Override
    public void close() {
        this.insert_stmt.close();
        this.select_by_rowid_stmt.close();
        this.select_id_by_hash_stmt.close();
        this.update_by_hash_stmt.close();
        this.select_count_by_rowid_stmt.close();
        this.update_by_rowid_stmt.close();
        super.close();
    }

    @Override
    public void create(Statement st) throws SQLException {
        st.executeUpdate("CREATE TABLE h2_vectable(id SERIAL PRIMARY KEY, count INTEGER, vec_hash BIGINT, vec CLOB)");
        st.executeUpdate("CREATE UNIQUE INDEX h2_vectable_index ON h2_vectable (vec_hash)");
    }

    @Override
    public void drop(Statement st) throws SQLException {
        this.vectorStore.invalidate();
        st.executeUpdate("DROP INDEX h2_vectable_index");
        super.drop(st);
    }

    @Override
    public long insert(Object ... arguments) throws SQLException {
        long id;
        if (arguments == null || arguments.length != 2) {
            throw new IllegalArgumentException("Insert method for H2VectorTable accepts two arguments: count(int) and LSHVector");
        }
        int count = (Integer)arguments[0];
        LSHVector vec = (LSHVector)arguments[1];
        PreparedStatement s = this.insert_stmt.prepareIfNeeded(() -> this.db.prepareStatement("INSERT INTO h2_vectable (count,vec_hash,vec) VALUES(?,?,?)", 1));
        StringBuilder vecBuf = new StringBuilder();
        vec.saveBase64(vecBuf, Base64Lite.encode);
        s.setInt(1, count);
        s.setLong(2, vec.calcUniqueHash());
        s.setString(3, vecBuf.toString());
        if (s.executeUpdate() != 1) {
            throw new SQLException("Insert failed for vector table");
        }
        try (ResultSet rs = s.getGeneratedKeys();){
            if (!rs.next()) {
                throw new SQLException("Unable to obtain vector id for insert");
            }
            id = rs.getLong(1);
        }
        this.vectorStore.update(new VectorStoreEntry(id, vec, count, this.vectorFactory.getSelfSignificance(vec)));
        return id;
    }

    public Map<Long, VectorStoreEntry> readVectors() throws SQLException {
        char[] vectorDecodeBuffer = Base64VectorFactory.allocateBuffer();
        HashMap<Long, VectorStoreEntry> map = new HashMap<Long, VectorStoreEntry>();
        try (Statement st = this.db.createStatement();
             ResultSet rs = st.executeQuery("SELECT id,count,vec FROM h2_vectable");){
            while (rs.next()) {
                long id = rs.getLong(1);
                int count = rs.getInt(2);
                StringReader r = new StringReader(rs.getString(3));
                LSHVector vec = this.vectorFactory.restoreVectorFromBase64(r, vectorDecodeBuffer);
                VectorStoreEntry entry = new VectorStoreEntry(id, vec, count, this.vectorFactory.getSelfSignificance(vec));
                map.put(id, entry);
            }
        }
        catch (IOException e) {
            throw new SQLException(e);
        }
        return map;
    }

    public VectorResult queryVectorById(long id) throws SQLException {
        VectorStoreEntry entry = this.vectorStore.getVectorById(id);
        if (entry != null) {
            return new VectorResult(id, entry.count(), 0.0, 0.0, entry.vec());
        }
        PreparedStatement s = this.select_by_rowid_stmt.prepareIfNeeded(() -> this.db.prepareStatement("SELECT id,count,vec FROM h2_vectable WHERE id = ?"));
        s.setLong(1, id);
        try (ResultSet rs = s.executeQuery();){
            VectorResult rowres;
            if (!rs.next()) {
                throw new SQLException("Bad vector table rowid");
            }
            char[] vectorDecodeBuffer = Base64VectorFactory.allocateBuffer();
            try {
                rowres = new VectorResult();
                rowres.vectorid = rs.getLong(1);
                rowres.hitcount = rs.getInt(2);
                StringReader r = new StringReader(rs.getString(3));
                rowres.vec = this.vectorFactory.restoreVectorFromBase64(r, vectorDecodeBuffer);
            }
            catch (IOException e) {
                throw new SQLException(e.getMessage());
            }
            VectorResult vectorResult = rowres;
            return vectorResult;
        }
    }

    private int queryVectorCountById(long id) throws SQLException {
        PreparedStatement s = this.select_count_by_rowid_stmt.prepareIfNeeded(() -> this.db.prepareStatement("SELECT count FROM h2_vectable WHERE id = ?"));
        s.setLong(1, id);
        try (ResultSet rs = s.executeQuery();){
            if (!rs.next()) {
                throw new SQLException("Bad vector table rowid");
            }
            int n = rs.getInt(1);
            return n;
        }
    }

    public long updateVector(LSHVector vec, int countDiff) throws SQLException {
        int count;
        long id;
        if (countDiff <= 0) {
            throw new IllegalArgumentException("Invalid countDiff: " + countDiff);
        }
        PreparedStatement s = this.update_by_hash_stmt.prepareIfNeeded(() -> this.db.prepareStatement("UPDATE h2_vectable SET count = count + ? WHERE vec_hash = ?"));
        long vecHash = vec.calcUniqueHash();
        s.setInt(1, countDiff);
        s.setLong(2, vecHash);
        int rc = s.executeUpdate();
        if (rc == 0) {
            return this.insert(countDiff, vec);
        }
        if (rc > 1) {
            throw new SQLException("Unexpected updated row count: " + rc);
        }
        s = this.select_id_by_hash_stmt.prepareIfNeeded(() -> this.db.prepareStatement("SELECT id, count FROM h2_vectable WHERE vec_hash = ?"));
        s.setLong(1, vecHash);
        try (ResultSet rs = s.executeQuery();){
            if (!rs.next()) {
                throw new SQLException("Unknown vector hash");
            }
            id = rs.getLong(1);
            count = rs.getInt(2);
        }
        this.vectorStore.update(new VectorStoreEntry(id, vec, count, this.vectorFactory.getSelfSignificance(vec)));
        return id;
    }

    public int deleteVector(long id, int countDiff) throws SQLException {
        if (countDiff <= 0) {
            throw new IllegalArgumentException("Invalid countDiff: " + countDiff);
        }
        PreparedStatement s = this.update_by_rowid_stmt.prepareIfNeeded(() -> this.db.prepareStatement("UPDATE h2_vectable SET count = count - ? WHERE id = ? AND count >= ?"));
        s.setInt(1, countDiff);
        s.setLong(2, id);
        s.setInt(3, countDiff);
        int rc = s.executeUpdate();
        if (rc == 0) {
            return -1;
        }
        if (rc > 1) {
            throw new SQLException("Unexpected updated row count: " + rc);
        }
        int count = this.queryVectorCountById(id);
        if (count > 0) {
            this.vectorStore.update(id, count);
            return 0;
        }
        this.delete(id);
        return 1;
    }

    @Override
    public int delete(long id) throws SQLException {
        int rc = super.delete(id);
        this.vectorStore.delete(id);
        return rc;
    }
}

