#!/usr/bin/env python

import os
import sys
import glob
import getopt
import exceptions
import datetime
import select
import string
import struct
import time
import traceback

import mc_bin_client
import memcacheConstants
import util

# sqlite3 is imported separately so that a missing module produces a short
# message instead of a traceback. Only ImportError is trapped: a bare except
# would also swallow KeyboardInterrupt and unrelated failures and misreport
# them as a Python version problem.
try:
    import sqlite3
except ImportError:
    sys.exit("ERROR: %s requires python version 2.6 or greater" %
              (os.path.basename(sys.argv[0])))

MBB_VERSION = "1" # sqlite pragma user version written into every new backup file.

DEFAULT_PORT      = "11210"                       # Used when -h gives a host without ":port".
DEFAULT_HOST_PORT = ["127.0.0.1", DEFAULT_PORT]   # [host, port]; the port is kept as a string.
DEFAULT_FILE      = "./backup-%.mbb"              # Output file name pattern (see -o).
DEFAULT_TXN_SIZE  = 5000                          # Records inserted between sqlite commits (see -t).
DEFAULT_MAX_BACKUP_SIZE = 512 # Default max size 512MB of a single incremental backup file generated

def usage(err=0):
    """Write the command-line synopsis to stderr and terminate the process.

    ``err`` is handed straight to ``sys.exit``: the default 0 exits
    successfully, while a string is printed by the interpreter and yields
    a failing exit status.
    """
    program = os.path.basename(sys.argv[0])
    host, port = DEFAULT_HOST_PORT[0], DEFAULT_HOST_PORT[1]
    synopsis = """
Usage: %s [-h %s[:%s]] [-o %s] [-s %s] [-T num_secs] [-r] [-c] [-d] [-t transaction_size] [-v] tap_name|db_file
""" % (program, host, port, DEFAULT_FILE, DEFAULT_MAX_BACKUP_SIZE)
    # The extra newline mirrors the line terminator a print statement appends.
    sys.stderr.write(synopsis + "\n")
    sys.exit(err)

def parse_args(args):
    """Parse the command-line arguments (without the program name).

    Returns the tuple
    ``(host_port, file, name, timeout, only_name, only_check, txn_size,
    verbosity, max_backup_size, split_backup, deduplicate)`` where ``name``
    is the single positional argument: the registered TAP client name, or,
    with -c, the backup file to inspect.

    Any problem (unknown option, non-numeric value for -s/-T/-t, wrong
    number of positional arguments) is reported through ``usage``, which
    terminates the process.
    """
    host_port = DEFAULT_HOST_PORT
    backup_file = DEFAULT_FILE
    timeout   = 0
    verbosity = 0
    only_name = False
    only_check = False
    txn_size  = DEFAULT_TXN_SIZE
    max_backup_size = DEFAULT_MAX_BACKUP_SIZE
    split_backup = False
    deduplicate = False

    def numeric(convert, opt, raw):
        # Turn a malformed number into a usage error instead of a traceback.
        try:
            return convert(raw)
        except ValueError:
            usage("invalid value for option %s: %s" % (opt, raw))

    try:
        opts, args = getopt.getopt(args, 'h:o:s:T:t:rcdv', ['help'])
    except getopt.GetoptError as e:
        usage(e.msg)

    for (o, a) in opts:
        if o == '--help':
            usage()
        elif o == '-h':
            host_port = a.split(':')
            if len(host_port) < 2:
                # No port given: fall back to the default one.
                host_port = [a, DEFAULT_PORT]
        elif o == '-o':
            backup_file = a
        elif o == '-s':
            max_backup_size = numeric(int, o, a)
            split_backup = True # Giving a size is what enables splitting.
        elif o == '-T':
            timeout = numeric(float, o, a)
        elif o == '-r':
            only_name = True # Only register the name and exit.
        elif o == '-c':
            only_check = True # Only check the incremental backup file by showing the list of checkpoints.
        elif o == '-d':
            deduplicate = True
        elif o == '-t':
            txn_size = numeric(int, o, a)
        elif o == '-v':
            verbosity = verbosity + 1 # Each -v raises the verbosity by one.
        else:
            usage("unknown option - " + o)

    if not args:
        usage("missing the registered client name or the incremental backup file name")
    if len(args) != 1:
        usage("incorrect number of arguments - only one registered name or backup file name needed")

    return host_port, backup_file, args[0], timeout, only_name, only_check, txn_size, verbosity,\
           max_backup_size, split_backup, deduplicate

def log(level, *args):
    """Print the string arguments, comma separated, when verbose enough.

    The message is shown only if ``level`` is strictly below the global
    ``verbosity``; it is indented by two spaces per level.
    """
    global verbosity
    if level >= verbosity:
        return
    message = ", ".join(args)
    indent = " " * (level * 2)
    print(indent + message)

def main():
    """Script entry point.

    Parses the command line, logs the effective settings and then either
    lists the checkpoints of an existing backup file (-c), only registers
    the TAP client name (-r), or streams the TAP connection into backup
    files via ``loop``. Errors end the process through ``sys.exit`` with an
    "ERROR: ..." message; the memcached connection is closed on the way out.
    """
    global verbosity

    (host_port, file, name, timeout, only_name, only_check, txn_size, verbosity,
     max_backup_size, split_backup, deduplicate) = parse_args(sys.argv[1:])

    settings = [
        "host_port = " + ':'.join(host_port),
        "verbosity = " + str(verbosity),
        "only_name = " + str(only_name),
        "only_check = " + str(only_check),
        "timeout = " + str(timeout),
        "transaction_size = " + str(txn_size),
        "file = " + file,
    ]
    if split_backup:
        settings.append("max size of a single backup file = %d MB" % max_backup_size)
    if only_check:
        settings.append("incremental_backup_file = " + name)
    else:
        settings.append("registered_client_name = " + name)
    for setting in settings:
        log(1, setting)

    mc = None

    try:
        if only_check:
            # In this mode the positional argument is a backup file, not a name.
            check_incremental_backup_file(name)
            sys.exit(0)

        host, port = host_port[0], int(host_port[1])
        mc = mc_bin_client.MemcachedClient(host, port)

        connect_opts = {
          memcacheConstants.TAP_FLAG_CHECKPOINT: '',
          memcacheConstants.TAP_FLAG_SUPPORT_ACK: '',
          memcacheConstants.TAP_FLAG_REGISTERED_CLIENT: 0x01, # "value > 0" means "closed checkpoints only"
          memcacheConstants.TAP_FLAG_BACKFILL: 0xffffffff
        }
        ext, val = encodeTAPConnectOpts(connect_opts)

        mc._sendCmd(memcacheConstants.CMD_TAP_CONNECT, name, val, 0, ext)

        if only_name:
            # Registration is considered successful when the first message
            # read back is a TAP opaque message.
            cmd, opaque, cas, vbucketId, key, ext, val = readTap(mc)
            if cmd == memcacheConstants.CMD_TAP_OPAQUE:
                sys.exit(0)
            sys.exit("ERROR: could not register name: " + name)

        backup_source = ':'.join(host_port) + '-' + name
        loop(mc, file, max_backup_size, split_backup, deduplicate, timeout, txn_size, backup_source)

    except NameError as ne:
        sys.exit("ERROR: " + str(ne))
    except Exception as e:
        # sys.exit() raises SystemExit, which is not caught here.
        if verbosity > 1:
            traceback.print_exc(file=sys.stdout)
            print(e)
        sys.exit("ERROR: " + str(e))
    finally:
        if mc:
            mc.close()

def check_incremental_backup_file(file_name):
    """Print the checkpoint states stored in the backup file ``file_name``.

    One line is printed per ``cpoint_state`` row, ordered by vbucket id and
    update time. The process exits with an error message when the file does
    not exist or cannot be read as expected.
    """
    db = None
    try:
        if not os.path.exists(file_name):
            sys.exit("ERROR: file does not exist: " + file_name)

        db = sqlite3.connect(file_name)
        db.text_factory = str
        cursor = db.cursor()
        cursor.execute("select vbucket_id, cpoint_id, state, updated_at from cpoint_state "
                       "order by vbucket_id asc, updated_at asc")
        print("\n vbucket_id   checkpoint_id   state    updated_at")
        print("------------------------------------------------------")
        for vbucket_id, cpoint_id, state, stamp in cursor:
            # updated_at is stored as a compact YYYYmmddHHMMSS string.
            parsed = time.strptime(str(stamp), "%Y%m%d%H%M%S")
            when = datetime.datetime.fromtimestamp(time.mktime(parsed))
            updated_at = when.strftime("%Y/%m/%d %H:%M:%S")
            columns = (str(vbucket_id).center(10),
                       str(cpoint_id).center(13),
                       str(state).center(6),
                       updated_at.center(19))
            print(" %s   %s   %s   %s" % columns)

    except Exception as e:
        if verbosity > 1:
            traceback.print_exc(file=sys.stdout)
            print(e)
        sys.exit("ERROR: " + str(e))
    finally:
        if db:
            db.close()

def get_available_backup_file_name(file_name_pattern):
    """Return the next backup file name derived from ``file_name_pattern``.

    The name comes from ``util.expand_file_pattern``; the process exits with
    an error if a file of that name is already present.
    """
    candidate = util.expand_file_pattern(file_name_pattern)
    log(1, "Next backup file = " + candidate)
    if not os.path.exists(candidate):
        return candidate
    sys.exit("ERROR: file exists already: " + candidate)

def add_record(cursor, stmt, fields):
    """Execute ``stmt`` with ``fields`` on ``cursor``.

    Returns True on success and False when sqlite reports any error.
    Callers interpret False as "the backup file reached its size limit";
    no narrower check is made here, so every ``sqlite3.Error`` is reported
    that way.
    """
    try:
        cursor.execute(stmt, fields)
    except sqlite3.Error as e:
        # str(e) rather than e.args[0]: the handler must not fail itself if
        # the error carries no (or a non-string) first argument.
        log(1, "Backup database size exceeds the max size allowed: " + str(e))
        return False
    return True

def update_checkpoint_state_in_split_backups(split_backup_files, chk_stmt, fields):
    """Run one statement against every backup file in ``split_backup_files``.

    Each file is opened, ``chk_stmt`` is executed with ``fields``, the
    change is committed and the file is closed again. An sqlite error
    propagates to the caller; the connection of the failing file is still
    closed.
    """
    for bfile in split_backup_files:
        db = sqlite3.connect(bfile)
        try:
            db.text_factory = str
            cursor = db.cursor()
            cursor.execute(chk_stmt, fields)
            db.commit()
            cursor.close()
        finally:
            # Do not leak the connection if the statement or commit fails.
            db.close()

def loop(mc, file, max_backup_size, split_backup, deduplicate, timeout, txn_size, source):
    """Read TAP messages from ``mc`` and record them in sqlite backup files.

    Parameters:
      mc              -- connected client; ``mc.s`` is handed to select() and
                         messages are read through ``readTap``.
      file            -- backup file name pattern.
      max_backup_size -- size limit in MB, applied when ``split_backup`` is set.
      split_backup    -- whether backup files are size limited.
      deduplicate     -- passed on to ``create_backup_db``.
      timeout         -- seconds of inactivity after which the process exits
                         (0 disables the check).
      txn_size        -- number of recorded operations between commits.
      source          -- label stored with every checkpoint state row.

    The function does not return normally: it leaves through ``sys.exit``
    (timeout, open checkpoint reached, protocol errors) or through an
    exception, which is printed and turned into an error exit. When an
    insert fails the current file is treated as full, remembered in
    ``split_backup_files`` and a new file is started.
    """
    split_backup_files = []  # Files of this run that were closed because they filled up.
    next_backup_file = get_available_backup_file_name(file)
    db = create_backup_db(next_backup_file, max_backup_size, split_backup, deduplicate)

    vbmap = {} # Key is vbucketId, value is [checkpointId, seq].

    cmdInfo = {
        memcacheConstants.CMD_TAP_MUTATION: ('mutation', 'm'),
        memcacheConstants.CMD_TAP_DELETE: ('delete', 'd'),
        memcacheConstants.CMD_TAP_FLUSH: ('flush', 'f'),
    }

    try:
        sinput = [mc.s]
        c = db.cursor()
        update_count = 0
        tap_stmt = "INSERT OR REPLACE into cpoint_op" \
                   "(vbucket_id, cpoint_id, seq, op, key, flg, exp, cas, val)" \
                   " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"
        chk_start_stmt = "INSERT into cpoint_state" \
                         "(vbucket_id, cpoint_id, prev_cpoint_id, state, source, updated_at)" \
                         " VALUES (?, ?, ?, \"open\", ?, ?)"
        chk_end_stmt = "UPDATE cpoint_state" \
                       " SET state=\"closed\", updated_at=?" \
                       " WHERE vbucket_id=? AND cpoint_id=? AND state=\"open\" AND source=?"
        del_chk_start_stmt = "DELETE FROM cpoint_state WHERE state=\"open\""
        op_records = [] ## Used for keeping the list of records in the current transaction

        while True:
            if timeout > 0:
                iready, oready, eready = select.select(sinput, [], [], timeout)
                if (not iready) and (not oready) and (not eready):
                    # NOTE(review): this exit does not commit the records of
                    # the open transaction -- confirm that is intended.
                    log(1, "EXIT: timeout after " + str(timeout) + " seconds of inactivity")
                    sys.exit(0)

            cmd, opaque, cas, vbucketId, key, ext, val = readTap(mc)
            log(2, "got " + str(cmd) + " k:" + key + " vlen:" + str(len(val))
                    + " elen:" + str(len(ext))
                    + " vbid:" + str(vbucketId))
            if len(val) > 0 and len(val) < 64:
                log(2, "  val: <<" + str(val) + ">>")

            needAck = False

            if (cmd == memcacheConstants.CMD_TAP_MUTATION or
                cmd == memcacheConstants.CMD_TAP_DELETE or
                cmd == memcacheConstants.CMD_TAP_FLUSH):
                cmdName, cmdOp = cmdInfo[cmd]
                if vbucketId not in vbmap:
                    # Operations are only accepted inside a started checkpoint.
                    log(2, "%s with unknown vbucketId: %s" % (cmdName, vbucketId))
                    sys.exit("ERROR: received %s without checkpoint in vbucket: %s\n" \
                             "Perhaps the server is an older version?"
                             % (cmdName, vbucketId))

                c_s = vbmap[vbucketId]
                checkpointId = c_s[0]
                seq          = c_s[1] = c_s[1] + 1  # Next sequence number within the checkpoint.

                eng_length, flags, ttl, flg, exp, needAck = parseTapExt(ext)

                val = sqlite3.Binary(val)
                op_records.append((vbucketId, checkpointId, seq, cmdOp,key, flg, exp, cas, val))
                result = add_record(c, tap_stmt, (vbucketId, checkpointId, seq, cmdOp,
                                                  key, flg, exp, cas, val))
                if not result:
                    ## The current backup db file is full
                    try:
                        db.rollback()
                    except sqlite3.Error as e: ## Can't find the better error for rollback failure.
                        log(1, "Insertion transaction was already rollbacked: " + str(e))
                    c.close()
                    db.close()
                    split_backup_files.append(next_backup_file)
                    ## Open the new backup db file
                    next_backup_file = get_available_backup_file_name(file)
                    db = create_backup_db(next_backup_file, max_backup_size,
                                                    split_backup, deduplicate)
                    c = db.cursor()
                    ## Insert checkpoint_start record
                    t = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
                    add_record(c, chk_start_stmt, (vbucketId, checkpointId, -1, source, t))
                    db.commit()
                    ## Insert all the records that belong to the rollbacked transaction.
                    for record in op_records:
                        add_record(c, tap_stmt, record)
                update_count = update_count + 1

            elif cmd == memcacheConstants.CMD_TAP_CHECKPOINT_START:
                if len(ext) > 0:
                    eng_length, flags, ttl, flg, exp, needAck = parseTapExt(ext)
                checkpoint_id = struct.unpack(">Q", val)
                checkpointStartExists = False
                if vbucketId in vbmap:
                    if vbmap[vbucketId][0] == checkpoint_id[0]:
                        # Repeated start of the checkpoint already being recorded.
                        checkpointStartExists = True
                    else:
                        sys.exit("ERROR: CHECKPOINT_START with checkpoint Id %s arrived" \
                                 " before receiving CHECKPOINT_END with checkpoint Id %s"
                                 % (checkpoint_id[0], vbmap[vbucketId][0]))
                if not checkpointStartExists:
                    vbmap[vbucketId] = [checkpoint_id[0], 0]
                    t = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
                    result = add_record(c, chk_start_stmt,
                                        (vbucketId, checkpoint_id[0], -1, source, t))
                    if not result:
                        ## The current backup db file is full and closed.
                        c.close()
                        db.close()
                        split_backup_files.append(next_backup_file)
                        ## Open the new backup db file
                        next_backup_file = get_available_backup_file_name(file)
                        # deduplicate must be passed here as well:
                        # create_backup_db takes four required arguments.
                        db = create_backup_db(next_backup_file, max_backup_size,
                                              split_backup, deduplicate)
                        c = db.cursor()
                        ## Insert checkpoint_start record that was rejected above
                        add_record(c, chk_start_stmt, (vbucketId, checkpoint_id[0], -1, source, t))
                    db.commit()

            elif cmd == memcacheConstants.CMD_TAP_CHECKPOINT_END:
                if update_count > 0:
                    db.commit()
                    op_records = []
                    update_count = 0
                checkpoint_id = struct.unpack(">Q", val)
                if vbucketId not in vbmap:
                    sys.exit("ERROR: unmatched checkpoint end: %s vb: %s"
                             % (checkpoint_id[0], vbucketId))

                current_checkpoint_id, seq = vbmap[vbucketId]
                if current_checkpoint_id != checkpoint_id[0]:
                    sys.exit("ERROR: unmatched checkpoint end id: %s vb: %s cp: %s"
                             % (checkpoint_id[0], vbucketId, current_checkpoint_id))

                if len(ext) > 0:
                    eng_length, flags, ttl, flg, exp, needAck = parseTapExt(ext)

                del vbmap[vbucketId]

                t = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
                r = c.execute(chk_end_stmt, (t, vbucketId, current_checkpoint_id, source))
                db.commit()
                if r.rowcount != 1:
                    # Values are converted with str() because join() only
                    # accepts strings, and the id reported is the one used in
                    # the UPDATE above.
                    sys.exit("ERROR: unexpected rowcount while updating checkpoint state: "
                             + ",".join([str(t), str(vbucketId),
                                         str(current_checkpoint_id), str(source)]))
                ## Update the corresponding checkpoint state if exists in the split backup files
                update_checkpoint_state_in_split_backups(split_backup_files, chk_end_stmt,
                                                         (t, vbucketId, current_checkpoint_id, source))

            elif cmd == memcacheConstants.CMD_TAP_OPAQUE:
                if len(ext) > 0:
                    eng_length, flags, ttl, flg, exp, needAck = parseTapExt(ext)
                    opaque_opcode = struct.unpack(">I" , val[0:eng_length])
                    if opaque_opcode[0] == memcacheConstants.TAP_OPAQUE_OPEN_CHECKPOINT:
                        if update_count > 0:
                            db.commit()
                            op_records = []
                            update_count = 0
                        ## Clear all checkpoint state records that have "open" state.
                        c.execute(del_chk_start_stmt)
                        db.commit()
                        update_checkpoint_state_in_split_backups(split_backup_files,
                                                                 del_chk_start_stmt, ())
                        log(1, "Incremental backup is currently at the open checkpoint. Exit...")
                        # sys.exit rather than the site-provided exit(), which
                        # is not guaranteed to be defined.
                        sys.exit(0)

            elif cmd == memcacheConstants.CMD_TAP_CONNECT:
                if update_count > 0:
                    db.commit()
                    op_records = []
                    update_count = 0
                sys.exit("ERROR: TAP_CONNECT error: " + str(key))

            elif cmd == memcacheConstants.CMD_NOOP:
                pass

            else:
                sys.exit("ERROR: unhandled cmd " + str(cmd))

            if update_count == txn_size:
                # Transaction is full: make the recorded operations durable.
                db.commit()
                op_records = []
                update_count = 0

            if needAck:
                mc._sendMsg(cmd, '', '', opaque,
                            vbucketId=0,
                            fmt=memcacheConstants.RES_PKT_FMT,
                            magic=memcacheConstants.RES_MAGIC_BYTE)

    except Exception as e:
        # sys.exit() raises SystemExit, which is not caught here.
        traceback.print_exc(file=sys.stdout)
        sys.exit("ERROR: " + str(e))
    finally:
        if db:
            db.close()

def readTap(mc):
    """Receive one message from ``mc`` and split its payload.

    The payload returned by ``mc._recvMsg()`` is laid out as extras, then
    key, then value. Returns ``(cmd, opaque, cas, vbucketId, key, ext, val)``;
    the three payload parts are empty strings when there is no payload.
    """
    cmd, vbucketId, opaque, cas, keylen, extlen, data = mc._recvMsg()
    if not data:
        return cmd, opaque, cas, vbucketId, '', '', ''
    key_end = extlen + keylen
    extras = data[:extlen]
    key = data[extlen:key_end]
    value = data[key_end:]
    return cmd, opaque, cas, vbucketId, key, extras, value

def encodeTAPConnectOpts(opts):
    """Encode a {flag: value} mapping for a TAP connect request.

    Returns ``(header, body)``: ``header`` is the bitwise OR of all flags
    packed as a big-endian unsigned 32-bit integer, and ``body`` is the
    concatenation of the values in ascending flag order. A value whose flag
    appears in ``memcacheConstants.TAP_FLAG_TYPES`` is packed with that
    struct format first; any other value is appended as given.
    """
    flag_formats = memcacheConstants.TAP_FLAG_TYPES
    flag_bits = 0
    pieces = []
    for flag in sorted(opts):
        flag_bits |= flag
        payload = opts[flag]
        if flag in flag_formats:
            payload = struct.pack(flag_formats[flag], payload)
        pieces.append(payload)
    return struct.pack(">I", flag_bits), ''.join(pieces)

def parseTapExt(ext):
    """Unpack the extras section of a TAP message.

    An 8-byte ``ext`` is decoded with ``TAP_GENERAL_PKT_FMT`` and carries no
    item flags or expiry, so both are reported as 0; any other length is
    decoded with ``TAP_MUTATION_PKT_FMT``. Returns
    ``(eng_length, flags, ttl, flg, exp, needAck)`` where ``needAck`` is
    ``flags`` masked with ``TAP_FLAG_ACK``.
    """
    if len(ext) == 8:
        header = struct.unpack(memcacheConstants.TAP_GENERAL_PKT_FMT, ext)
        fields = header + (0, 0)
    else:
        fields = struct.unpack(memcacheConstants.TAP_MUTATION_PKT_FMT, ext)
    eng_length, flags, ttl, flg, exp = fields

    return (eng_length, flags, ttl, flg, exp,
            flags & memcacheConstants.TAP_FLAG_ACK)

def create_backup_db(backup_file_name, max_backup_size, split_backup, deduplicate):
    """Create a new sqlite backup file and return its open connection.

    The file must be new (sqlite user_version 0); otherwise the process
    exits. The ``cpoint_op`` and ``cpoint_state`` tables are created and the
    file is stamped with ``MBB_VERSION``. With ``deduplicate`` the
    ``cpoint_op`` table gets a (vbucket_id, key) primary key. With
    ``split_backup`` the file is limited to ``max_backup_size`` MB through
    sqlite's max_page_count pragma. Any other failure prints a traceback
    and exits.
    """
    db = None
    try:
        db = sqlite3.connect(backup_file_name) # TODO: Revisit isolation level
        db.text_factory = str
        file_version = db.execute("pragma user_version").fetchall()[0][0]
        if file_version != 0:
            sys.exit("ERROR: unexpected db user version: " + str(file_version))

        # Both schema variants share this script; only the end of the
        # cpoint_op column list differs.
        schema_template = """
            BEGIN;
            CREATE TABLE cpoint_op
            (vbucket_id integer, cpoint_id integer, seq integer, op text,
            key varchar(250), flg integer, exp integer, cas integer, val blob%s);
            CREATE TABLE cpoint_state
            (vbucket_id integer, cpoint_id integer, prev_cpoint_id integer, state varchar(1),
            source varchar(250), updated_at text);
            pragma user_version=%s;
            COMMIT;
            """
        if deduplicate:
            op_key_clause = ",\n            primary key(vbucket_id, key)"
        else:
            op_key_clause = ""
        db.executescript(schema_template % (op_key_clause, MBB_VERSION))

        if split_backup:
            size_in_bytes = max_backup_size * 1024 * 1024 # Convert MB to bytes
            page_size = db.execute("pragma page_size").fetchone()[0]
            # Whole pages only.
            page_limit = size_in_bytes // page_size
            db.execute("pragma max_page_count=%d" % (page_limit))
    except Exception as e:
        traceback.print_exc(file=sys.stdout)
        sys.exit("ERROR: " + str(e))
    return db

if __name__ == '__main__':
    main()
