#!/usr/bin/env python

import os
import sys
import glob
import getopt
import exceptions
import string
import traceback
import backup_util
import subprocess
import util
import time

# Only a failed import should produce the version hint. A bare "except:" would
# also intercept KeyboardInterrupt/SystemExit and misreport them.
try:
    import sqlite3
except ImportError:
    sys.exit("ERROR: %s requires python version 2.6 or greater" %
              (os.path.basename(sys.argv[0])))

DEFAULT_OUTPUT_FILE = "./squashed-%.mbb" # Default name (pattern) of the output file
DEFAULT_MAX_DB_SIZE = 512 # Default max size 512MB of a merged database file

def usage(err=0):
    """Write the command synopsis to stderr, then terminate the process.

    err is passed unchanged to sys.exit(): the default 0 is a normal exit,
    while a message string is printed by the interpreter and yields a
    failure exit status.
    """
    program = os.path.basename(sys.argv[0])
    synopsis = """
Usage: %s [-o %s] [-s %s] [-v] [-f] incremental_backup_file1 incremental_backup_file2 ...
""" % (program, DEFAULT_OUTPUT_FILE, DEFAULT_MAX_DB_SIZE)
    ## The synopsis is followed by exactly one extra newline.
    sys.stderr.write(synopsis + "\n")
    sys.exit(err)

def parse_args(args):
    """Parse the command-line arguments (program name already stripped).

    Recognized options:
      -o FILE   output file (default DEFAULT_OUTPUT_FILE)
      -s SIZE   max size in MB of a merged database file; also turns on
                splitting of the merged output
      -v        increase verbosity (may be repeated)
      -f        force the merge
      --help    print the usage text and exit

    Returns the tuple
    (output_file, max_db_size, input_files, verbosity, split_backup, force_merge).
    Any invalid option, an invalid -s value, or a missing input file list
    ends the process through usage().
    """
    output_file = DEFAULT_OUTPUT_FILE
    max_db_size = DEFAULT_MAX_DB_SIZE
    verbosity = 0
    split_backup = False
    force_merge = False

    try:
        opts, args = getopt.getopt(args, 'o:s:vf', ['help'])
    except getopt.GetoptError as e:
        usage(e.msg)

    for (o, a) in opts:
        if o == '--help':
            usage()
        elif o == '-o':
            output_file = a
        elif o == '-s':
            ## Report a bad size through usage() instead of letting int()
            ## abort the script with a ValueError traceback.
            try:
                max_db_size = int(a)
            except ValueError:
                usage("invalid max database size - " + a)
            if max_db_size < 1:
                usage("max database size must be a positive number of MB - " + a)
            split_backup = True
        elif o == '-v':
            verbosity = verbosity + 1
        elif o == '-f':
            force_merge = True
        else:
            usage("unknown option - " + o)

    if not args:
        usage("missing incremental backup files")

    return output_file, max_db_size, args, verbosity, split_backup, force_merge

def findCmd(cmdName):
    """Return the path of the first existing candidate for command cmdName.

    Candidates are checked in this order: cmdName and cmdName + '.exe' in the
    directory of the running script (sys.argv[0]), then the same two names in
    "../../bin" relative to that directory.

    Raises IndexError (the exception type callers have always received) when
    none of the candidates exists, now with a message naming the command.
    """
    cmd_dir = os.path.dirname(sys.argv[0])
    bin_dirs = [cmd_dir, os.path.join(cmd_dir, "..", "..", "bin")]
    for bin_dir in bin_dirs:
        for name in (cmdName, cmdName + '.exe'):
            candidate = os.path.join(bin_dir, name)
            if os.path.exists(candidate):
                return candidate
    raise IndexError("command \"%s\" not found in any of: %s"
                     % (cmdName, ", ".join(bin_dirs)))

def log(level, *args):
    """Print the given strings, joined by ", ", when verbose enough.

    A message is shown only if level is strictly lower than the global
    verbosity. The line is indented by two spaces per level.
    """
    global verbosity
    if level >= verbosity:
        return
    message = ", ".join(args)
    ## Right-justifying to len + 2*level is a left pad of 2*level spaces.
    print(message.rjust(len(message) + (level * 2)))

def create_single_output_db(db_file_name):
    """Open (or create) the merged output database and ensure its schema.

    The cpoint_op and cpoint_state tables are created only if they do not
    exist yet, so an already existing output file can be reused.

    Returns the open sqlite3 connection. On any failure the traceback is
    printed to stdout, the connection (if one was opened) is closed, and the
    process exits with an "ERROR: ..." message.
    """
    db = None
    try:
        db = sqlite3.connect(db_file_name)
        db.text_factory = str
        db.executescript("""
        BEGIN;
        CREATE TABLE IF NOT EXISTS cpoint_op
        (vbucket_id integer, cpoint_id integer, seq integer, op text,
        key varchar(250), flg integer, exp integer, cas integer, val blob,
        primary key(vbucket_id, key));
        CREATE TABLE IF NOT EXISTS cpoint_state
        (vbucket_id integer, cpoint_id integer, prev_cpoint_id integer, state varchar(1),
        source varchar(250), updated_at text,
        primary key(vbucket_id, cpoint_id));
        COMMIT;
        """)
    except Exception as e:
        traceback.print_exc(file=sys.stdout)
        ## Do not leave a half-initialized connection open.
        if db is not None:
            db.close()
        ## sys.exit() rather than the interactive-shell helper exit(),
        ## which only exists when the site module is loaded.
        sys.exit("ERROR: " + str(e))
    return db

def merge_incremental_backup_files(backup_files, single_output_file):
    """Merge the closed checkpoints of every backup file into one database.

    backup_files        paths of the incremental backup files, processed in
                        the given order
    single_output_file  path of the merged database (created if missing)

    Rows are written with INSERT OR IGNORE, so for a given primary key the
    first row that reaches the output database is the one that is kept.
    Within one backup file the operations are read by descending checkpoint
    id.

    Exits the process with an error message if a batch cannot be copied or
    any other exception occurs. Both database connections are closed on every
    path.
    """
    ## Bound before the try block so that the finally clause never touches an
    ## unassigned name when create_single_output_db() itself exits.
    output_db = None
    backup_db = None
    try:
        output_db = create_single_output_db(single_output_file)
        for bfile in backup_files:
            log(1, "Incremental backup file: \"%s\"" % bfile)
            backup_db = sqlite3.connect(bfile)
            backup_db.text_factory = str
            bd_cursor = backup_db.cursor()
            bd_cursor.arraysize = 5000
            ## Insert checkpoint_state records into the output db file.
            ## 'closed' uses single quotes: double quotes denote identifiers in
            ## SQL and are accepted as strings only by a legacy SQLite fallback.
            bd_cursor.execute("SELECT vbucket_id, cpoint_id, prev_cpoint_id, state, " \
                              "source, updated_at FROM cpoint_state " \
                              "WHERE state = 'closed'")
            checkpoint_states = bd_cursor.fetchall()
            copy_checkpoint_state_records(checkpoint_states, output_db)

            ## Insert checkpoint_operation records into the output db file,
            ## one batch of arraysize rows at a time.
            bd_cursor.execute("SELECT cpoint_op.vbucket_id, cpoint_op.cpoint_id, seq, " \
                              "op, key, flg, exp, cas, val " \
                              "FROM cpoint_state JOIN cpoint_op ON " \
                              "(cpoint_op.vbucket_id = cpoint_state.vbucket_id AND " \
                              "cpoint_op.cpoint_id = cpoint_state.cpoint_id) " \
                              "WHERE cpoint_state.state = 'closed' " \
                              "ORDER BY cpoint_op.cpoint_id DESC")
            while True:
                op_records = bd_cursor.fetchmany(bd_cursor.arraysize)
                if not op_records:
                    break
                if not copy_checkpoint_op_records(op_records, output_db):
                    sys.exit("Error in merging \"%s\"" % bfile)
            backup_db.close()
            backup_db = None
    except Exception as e:
        traceback.print_exc(file=sys.stdout)
        sys.exit("ERROR: " + str(e))
    finally:
        if backup_db is not None:
            backup_db.close()
        if output_db is not None:
            output_db.close()

def create_split_db(db_file_name, max_db_size):
    """Create a new split database whose size is capped at max_db_size MB.

    The tables are created without IF NOT EXISTS, so the call fails if
    db_file_name already holds them. The size cap is applied through the
    max_page_count pragma, computed from the database page size.

    Returns the open sqlite3 connection. On any failure the traceback is
    printed to stdout, the connection (if one was opened) is closed, and the
    process exits with an "ERROR: ..." message.
    """
    db = None
    max_db_size = max_db_size * 1024 * 1024 # Convert MB to bytes
    try:
        db = sqlite3.connect(db_file_name)
        db.text_factory = str
        db.executescript("""
        BEGIN;
        CREATE TABLE cpoint_op
        (vbucket_id integer, cpoint_id integer, seq integer, op text,
        key varchar(250), flg integer, exp integer, cas integer, val blob,
        primary key(vbucket_id, key));
        CREATE TABLE cpoint_state
        (vbucket_id integer, cpoint_id integer, prev_cpoint_id integer, state varchar(1),
        source varchar(250), updated_at text,
        primary key(vbucket_id, cpoint_id));
        COMMIT;
        """)
        db_page_size = db.execute("pragma page_size").fetchone()[0]
        ## Explicit floor division: the page count must be a whole number.
        db_max_page_count = max_db_size // db_page_size
        db.execute("pragma max_page_count=%d" % (db_max_page_count))
    except Exception as e:
        traceback.print_exc(file=sys.stdout)
        ## Do not leave a half-initialized connection open.
        if db is not None:
            db.close()
        sys.exit("ERROR: " + str(e))
    return db

def copy_checkpoint_state_records(checkpoint_states, db):
    """Insert checkpoint state rows into db's cpoint_state table and commit.

    Each row supplies, by position, vbucket_id, cpoint_id, prev_cpoint_id,
    state, source and updated_at. Rows whose primary key is already present
    are skipped (INSERT OR IGNORE).
    """
    insert_sql = "INSERT OR IGNORE into cpoint_state" \
                 "(vbucket_id, cpoint_id, prev_cpoint_id, state, source, updated_at)" \
                 " VALUES (?, ?, ?, ?, ?, ?)"
    rows = [(r[0], r[1], r[2], r[3], r[4], r[5]) for r in checkpoint_states]
    cursor = db.cursor()
    cursor.executemany(insert_sql, rows)
    db.commit()
    cursor.close()

def copy_checkpoint_op_records(op_records, db):
    """Insert a batch of checkpoint operation rows into db's cpoint_op table.

    Each row supplies, by position, vbucket_id, cpoint_id, seq, op, key, flg,
    exp, cas and val. Rows whose primary key is already present are skipped
    (INSERT OR IGNORE). val is stored as a blob; a NULL val is kept as NULL.

    Returns True after committing the whole batch. Returns False, without
    committing, if sqlite3 raises an error; the caller decides whether to
    roll back. The cursor is closed on every path.
    """
    stmt = "INSERT OR IGNORE into cpoint_op" \
           "(vbucket_id, cpoint_id, seq, op, key, flg, exp, cas, val)" \
           " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"
    c = db.cursor()
    try:
        for oprecord in op_records:
            val = oprecord[8]
            ## sqlite3.Binary() cannot wrap None; pass NULL values through.
            if val is not None:
                val = sqlite3.Binary(val)
            c.execute(stmt, (oprecord[0], oprecord[1], oprecord[2], oprecord[3], oprecord[4],
                             oprecord[5], oprecord[6], oprecord[7], val))
    except sqlite3.Error as e:
        ## No specific error code is available for "database full", so every
        ## sqlite3 error is reported as the size limit being hit.
        log(1, "The database size exceeds the max size allowed: " + e.args[0])
        return False
    finally:
        c.close()
    db.commit()
    return True

def split_single_merged_db_file(single_output_file, split_output_file, max_db_size):
    """Split one merged database into several size-limited database files.

    single_output_file  path of the merged database to read
    split_output_file   file name pattern handed to util.expand_file_pattern()
                        to obtain the name of each split file
    max_db_size         max size in MB of each split file

    Every split file receives all cpoint_state rows. The cpoint_op rows are
    copied in batches; when a batch does not fit, the current split file is
    rolled back and closed, and the same batch is written to a new one.

    Exits the process with an error message if a batch does not fit even into
    a fresh split file or any other exception occurs. Both database
    connections are closed on every path.
    """
    merged_db = None
    split_db = None
    try:
        ## Find the next split output file to start with
        next_split_output_file = util.expand_file_pattern(split_output_file)
        split_db = create_split_db(next_split_output_file, max_db_size)
        log(1, "Merged database file: \"%s\"" % next_split_output_file)
        merged_db = sqlite3.connect(single_output_file)
        merged_db.text_factory = str
        md_cursor = merged_db.cursor()
        md_cursor.arraysize = 5000

        ## Copy checkpoint_state records into the first split db.
        ## Note that the number of checkpoint_state records is usually small even in the merged db.
        md_cursor.execute("select vbucket_id, cpoint_id, prev_cpoint_id, state, source, updated_at " \
                          "from cpoint_state")
        checkpoint_states = md_cursor.fetchall()
        copy_checkpoint_state_records(checkpoint_states, split_db)

        ## Copy checkpoint_operation records into the multiple split database files.
        md_cursor.execute("select vbucket_id, cpoint_id, seq, op, key, flg, exp, cas, val " \
                          "from cpoint_op")
        while True:
            op_records = md_cursor.fetchmany(md_cursor.arraysize)
            if not op_records:
                break
            if copy_checkpoint_op_records(op_records, split_db):
                continue
            ## The current split database size exceeds the max size allowed.
            ## Create the next split database and continue to copy records.
            try:
                split_db.rollback()
            except sqlite3.Error as e: ## No better error code is known for a rollback failure.
                log(1, "Insertion transaction was already rollbacked: " + e.args[0])
            split_db.close()
            split_db = None
            next_split_output_file = util.expand_file_pattern(split_output_file)
            split_db = create_split_db(next_split_output_file, max_db_size)
            log(1, "Merged database file: \"%s\"" % next_split_output_file)
            copy_checkpoint_state_records(checkpoint_states, split_db)
            if not copy_checkpoint_op_records(op_records, split_db):
                ## Ignoring this result would silently drop the whole batch
                ## from the split output.
                sys.exit("ERROR: a batch of %d records does not fit into the empty "
                         "split database \"%s\"; increase the max database size"
                         % (len(op_records), next_split_output_file))
    except Exception as e:
        traceback.print_exc(file=sys.stdout)
        sys.exit("ERROR: " + str(e))
    finally:
        if merged_db is not None:
            merged_db.close()
        if split_db is not None:
            split_db.close()

def main():
    """Merge the incremental backup files named on the command line.

    Steps: parse the arguments, expand each input argument as a glob pattern,
    validate the checkpoint sequence of the files (together with an already
    existing output file), merge them into a single database and, when -s was
    given, split that database into size-limited files.

    Exits the process with an error message when an input pattern matches no
    file, when checkpoints are missing between the output file and the input
    files, or when a single file is given without -s or -f.
    """
    global verbosity

    output_file, max_db_size, input_files, verbosity, split_backup, force_merge = parse_args(sys.argv[1:])
    log(1, "incremental backup files = " + ' '.join(input_files))
    log(1, "output backup file = %s" % output_file)
    if split_backup:
        log(1, "max size of a single merged database = %d MB" % max_db_size)
    log(1, "verbosity = " + str(verbosity) + "\n")

    backup_files = []
    for pattern in input_files:
        bfiles = glob.glob(pattern)
        if not bfiles:
            sys.exit("Backup file '%s' does not exist!!!" % (pattern))
        backup_files.extend(bfiles)

    ## Check if there are any missing checkpoints in the backup files
    if os.path.exists(output_file):
        backup_files.insert(0, output_file)
        backup_files = backup_util.validate_incremental_backup_files(backup_files)
        ## An empty result is treated like a broken checkpoint chain instead
        ## of failing with an IndexError.
        if not backup_files or backup_files[0] != output_file:
            sys.exit("There are missing checkpoints between the "
                     "output file and input backup files")
        backup_files.remove(output_file)
    else:
        backup_files = backup_util.validate_incremental_backup_files(backup_files)

    if split_backup:
        ## Merge into a timestamped intermediate file that lives in the same
        ## directory as the output file; it is split afterwards.
        timestamp = str(time.time())
        single_output_file = "./" + timestamp + ".mbb"
        rindex = output_file.rfind('/')
        if rindex != -1:
            single_output_file = output_file[0:rindex+1] + timestamp + '.mbb'
    else:
        single_output_file = output_file

    if len(backup_files) == 1 and not force_merge:
        if not split_backup:
            sys.exit("You should provide at least two files for merging "
                     "without splitting")
        ## Nothing to merge: split the single input file directly.
        single_output_file = backup_files[0]
    else:
        ## Merge all incremental backup files into a single database file
        merge_incremental_backup_files(backup_files, single_output_file)

    if split_backup:
        ## Split the single merged database file into multiple database files,
        ## so that each split file size does not exceed the max size provided.
        split_single_merged_db_file(single_output_file, output_file, max_db_size)
        ## NOTE(review): when only one input file was given, single_output_file
        ## is that input file, so this removes the user's original backup file
        ## after it has been split. Confirm that this is intended.
        os.unlink(single_output_file)

    log(1, "\n  Merging incremental backup files are completed.\n")

## Run the merge only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
