#!/usr/bin/env python

import os
import sys
import glob
import getopt
import exceptions
import datetime
import select
import string
import struct
import time
import traceback

import mc_bin_client
import memcacheConstants
import util

# sqlite3 is part of the standard library from Python 2.5 on; if the import
# fails, stop with a readable message instead of an ImportError traceback.
# Only ImportError is caught so that Ctrl-C or an unrelated failure during
# import is not misreported as a Python version problem.
try:
    import sqlite3
except ImportError:
    sys.exit("ERROR: %s requires python version 2.6 or greater" %
              (os.path.basename(sys.argv[0])))

# Schema version that createSchema() stamps into a new backup file through
# sqlite's "pragma user_version".
MBB_VERSION = "1" # sqlite pragma user version.

# Port used when -h is given a bare host name without a ":port" suffix.
DEFAULT_PORT      = "11210"
# Default [host, port] pair for the server connection; both are strings.
DEFAULT_HOST_PORT = ["127.0.0.1", DEFAULT_PORT]
# Default output file pattern. main() passes it through
# util.expand_file_pattern() before use (presumably that call replaces
# the '%' -- confirm in util).
DEFAULT_FILE      = "./backup-%.mbb"
# Default number of recorded operations between sqlite commits.
DEFAULT_TXN_SIZE  = 5000

def usage(err=0):
    """Write the command-line synopsis to stderr and terminate the process.

    err is handed straight to sys.exit(): 0 (the default) exits cleanly,
    while a message string is printed by the interpreter and yields a
    failing exit status.
    """
    prog = os.path.basename(sys.argv[0])
    default_host, default_port = DEFAULT_HOST_PORT[0], DEFAULT_HOST_PORT[1]
    synopsis = """
Usage: %s [-h %s[:%s]] [-o %s] [-T num_secs] [-r] [-c] [-t transaction_size] [-v] tap_name|db_file
""" % (prog, default_host, default_port, DEFAULT_FILE)
    # Same bytes as the print statement: the text followed by one newline.
    sys.stderr.write(synopsis + "\n")
    sys.exit(err)

def parse_args(args):
    """Parse the command-line arguments (without the program name).

    Recognised options:
      --help   print the usage text and exit
      -h HOST[:PORT]  server address; PORT defaults to DEFAULT_PORT
      -o FILE  output file (pattern) for the backup
      -T SECS  inactivity timeout in seconds (float); 0 leaves it disabled
      -r       only register the TAP client name, then exit
      -c       only list the checkpoints stored in an existing backup file
      -t N     number of operations per sqlite transaction (int)
      -v       increase verbosity; may be repeated

    Exactly one positional argument is required: the registered client name,
    or the backup file name when -c is given.

    Returns the tuple
      (host_port, file, name, timeout, only_name, only_check, txn_size,
       verbosity)
    where host_port is a list of strings.

    Any invalid input ends the process through usage().
    """
    host_port   = DEFAULT_HOST_PORT
    backup_file = DEFAULT_FILE
    timeout     = 0
    verbosity   = 0
    only_name   = False
    only_check  = False
    txn_size    = DEFAULT_TXN_SIZE

    try:
        opts, args = getopt.getopt(args, 'h:o:T:t:rcv', ['help'])
    except getopt.GetoptError as e:
        usage(e.msg)

    for (o, a) in opts:
        if o == '--help':
            usage()
        elif o == '-h':
            host_port = a.split(':')
            if len(host_port) < 2:
                host_port = [a, DEFAULT_PORT]
        elif o == '-o':
            backup_file = a
        elif o == '-T':
            # Report a malformed number as a usage error, not a traceback.
            try:
                timeout = float(a)
            except ValueError:
                usage("invalid timeout (expected a number of seconds) - " + a)
        elif o == '-r':
            only_name = True # Only register the name and exit.
        elif o == '-c':
            only_check = True # Only check the incremental backup file by showing the list of checkpoints.
        elif o == '-t':
            try:
                txn_size = int(a)
            except ValueError:
                usage("invalid transaction size (expected an integer) - " + a)
        elif o == '-v':
            verbosity = verbosity + 1
        else:
            usage("unknown option - " + o)

    if not args:
        usage("missing the registered client name or the incremental backup file name")
    if len(args) != 1:
        usage("incorrect number of arguments - only one registered name or backup file name needed")

    return host_port, backup_file, args[0], timeout, only_name, only_check, txn_size, verbosity

def log(level, *args):
    """Print the given strings, joined by ", ", when level is below the
    module-wide verbosity.

    The line is indented by two spaces per level.
    """
    global verbosity
    if level >= verbosity:
        return
    message = ", ".join(args)
    print((" " * (level * 2)) + message)

def main():
    """Script entry point.

    Parses the command line and then runs one of three modes:
      * -c: list the checkpoints of an existing backup file and exit;
      * -r: register the TAP client name with the server and exit;
      * otherwise: create a new sqlite backup file and stream TAP messages
        into it through loop().

    Failures are reported as "ERROR: ..." through sys.exit(). The server
    connection and the database are closed on every exit path.
    """
    global verbosity

    (host_port, backup_file, name, timeout,
     only_name, only_check, txn_size, verbosity) = parse_args(sys.argv[1:])

    endpoint = ':'.join(host_port)
    if only_check:
        name_line = "incremental_backup_file = " + name
    else:
        name_line = "registered_client_name = " + name

    settings = [
        "host_port = " + endpoint,
        "verbosity = " + str(verbosity),
        "only_name = " + str(only_name),
        "only_check = " + str(only_check),
        "timeout = " + str(timeout),
        "transaction_size = " + str(txn_size),
        "file = " + backup_file,
        name_line,
    ]
    for line in settings:
        log(1, line)

    mc = None
    db = None

    try:
        if only_check:
            check_incremental_backup_file(name)
            sys.exit(0)

        mc = mc_bin_client.MemcachedClient(host_port[0], int(host_port[1]))

        tap_options = {
            memcacheConstants.TAP_FLAG_CHECKPOINT: '',
            memcacheConstants.TAP_FLAG_SUPPORT_ACK: '',
            # "value > 0" means "closed checkpoints only"
            memcacheConstants.TAP_FLAG_REGISTERED_CLIENT: 0x01,
        }
        ext, val = encodeTAPConnectOpts(tap_options)

        mc._sendCmd(memcacheConstants.CMD_TAP_CONNECT, name, val, 0, ext)

        if only_name:
            # Registration succeeded if the first reply is a TAP opaque message.
            reply_cmd = readTap(mc)[0]
            if reply_cmd != memcacheConstants.CMD_TAP_OPAQUE:
                sys.exit("ERROR: could not register name: " + name)
            sys.exit(0)

        backup_file = util.expand_file_pattern(backup_file)
        log(1, "file actual = " + backup_file)
        if os.path.exists(backup_file):
            sys.exit("ERROR: file exists already: " + backup_file)

        db = sqlite3.connect(backup_file) # TODO: Revisit isolation level
        db.text_factory = str
        createSchema(db)

        loop(mc, db, timeout, txn_size, endpoint + '-' + name)

    except NameError as ne:
        sys.exit("ERROR: " + str(ne))
    except Exception as e:
        if verbosity > 1:
            traceback.print_exc(file=sys.stdout)
            print(e)
        sys.exit("ERROR: " + str(e))
    finally:
        if mc:
            mc.close()
        if db:
            db.close()

def check_incremental_backup_file(file_name):
    """Print a table of the checkpoints recorded in an incremental backup file.

    Reads the cpoint_state table of the sqlite file file_name and prints one
    row per entry (vbucket id, checkpoint id, state, last update time),
    ordered by vbucket id and then by update time.

    Exits the process with an "ERROR: ..." message when the file does not
    exist or when reading it fails. The database handle is always closed.
    """
    db = None
    try:
        if not os.path.exists(file_name):
            sys.exit("ERROR: file does not exist: " + file_name)

        db = sqlite3.connect(file_name)
        db.text_factory = str
        c = db.cursor()
        c.execute("select vbucket_id, cpoint_id, state, updated_at from cpoint_state " +
                  "order by vbucket_id asc, updated_at asc")
        print("\n vbucket_id   checkpoint_id   state    updated_at")
        print("------------------------------------------------------")
        for vbucket_id, cpoint_id, state, updated_raw in c:
            # updated_at is stored as a compact YYYYMMDDHHMMSS string; parse
            # it directly rather than round-tripping through an epoch value.
            updated = datetime.datetime.strptime(str(updated_raw), "%Y%m%d%H%M%S")
            updated_at = updated.strftime("%Y/%m/%d %H:%M:%S")
            print(" %s   %s   %s   %s" % (str(vbucket_id).center(10),
                                          str(cpoint_id).center(13),
                                          str(state).center(6),
                                          updated_at.center(19)))

    except Exception as e:
        if verbosity > 1:
            traceback.print_exc(file=sys.stdout)
            print(e)
        sys.exit("ERROR: " + str(e))
    finally:
        if db:
            db.close()

def loop(mc, db, timeout, txn_size, source):
    """Receive TAP messages from mc and record them in the sqlite db.

    Parameters:
      mc       -- connected client; its socket is read from mc.s, messages
                  are read through readTap() and acks sent via mc._sendMsg()
      db       -- open sqlite3 connection with the backup schema created
      timeout  -- seconds of socket inactivity after which the process exits
                  with status 0; values <= 0 disable the check
      txn_size -- commit after this many recorded operations
      source   -- string stored in cpoint_state.source for each checkpoint

    Message handling:
      * mutation / delete / flush: inserted into cpoint_op under the
        checkpoint currently open for that vbucket;
      * checkpoint start: inserts an "open" row into cpoint_state;
      * checkpoint end: commits and marks that row "closed";
      * opaque "open checkpoint": commits pending work and exits with 0;
      * TAP connect: treated as a server-side error, exits with a message;
      * noop: ignored.
    Any other command, or a protocol inconsistency, ends the process with an
    "ERROR: ..." message through sys.exit().

    Returns None when EOFError is raised while reading (presumably the
    server closing the connection -- confirm in mc_bin_client).
    """
    vbmap = {} # Key is vbucketId, value is [checkpointId, seq].

    cmdInfo = {
        memcacheConstants.CMD_TAP_MUTATION: ('mutation', 'm'),
        memcacheConstants.CMD_TAP_DELETE: ('delete', 'd'),
        memcacheConstants.CMD_TAP_FLUSH: ('flush', 'f'),
    }

    try:
        sinput = [mc.s]
        c = db.cursor()
        update_count = 0 # Operations recorded since the last commit.

        while True:
            if timeout > 0:
                iready, oready, eready = select.select(sinput, [], [], timeout)
                if (not iready) and (not oready) and (not eready):
                    log(1, "EXIT: timeout after " + str(timeout) + " seconds of inactivity")
                    sys.exit(0)

            cmd, opaque, cas, vbucketId, key, ext, val = readTap(mc)
            log(2, "got " + str(cmd) + " k:" + key + " vlen:" + str(len(val))
                    + " elen:" + str(len(ext))
                    + " vbid:" + str(vbucketId))
            if len(val) > 0 and len(val) < 64:
                log(2, "  val: <<" + str(val) + ">>")

            needAck = False

            if cmd in cmdInfo:
                cmdName, cmdOp = cmdInfo[cmd]
                if vbucketId not in vbmap:
                    log(2, "%s with unknown vbucketId: %s" % (cmdName, vbucketId))
                    sys.exit("ERROR: received %s without checkpoint in vbucket: %s\n" \
                             "Perhaps the server is an older version?"
                             % (cmdName, vbucketId))

                # Advance the per-checkpoint sequence number in place.
                c_s = vbmap[vbucketId]
                checkpointId = c_s[0]
                seq          = c_s[1] = c_s[1] + 1

                # needAck comes straight from parseTapExt (flags & TAP_FLAG_ACK).
                eng_length, flags, ttl, flg, exp, needAck = parseTapExt(ext)

                if cmd == memcacheConstants.CMD_TAP_MUTATION:
                    s = "INSERT into cpoint_op" \
                        "(vbucket_id, cpoint_id, seq, op, key, flg, exp, cas, val)" \
                        " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"
                    c.execute(s, (vbucketId, checkpointId, seq, cmdOp,
                                   key, flg, exp, cas, sqlite3.Binary(val)))
                else:
                    # Deletes and flushes carry no item flags or expiry.
                    s = "INSERT into cpoint_op" \
                        "(vbucket_id, cpoint_id, seq, op, key, cas, val)" \
                        " VALUES (?, ?, ?, ?, ?, ?, ?)"
                    c.execute(s, (vbucketId, checkpointId, seq, cmdOp,
                                   key, cas, sqlite3.Binary(val)))

                update_count = update_count + 1

            elif cmd == memcacheConstants.CMD_TAP_CHECKPOINT_START:
                if len(ext) > 0:
                    eng_length, flags, ttl, flg, exp, needAck = parseTapExt(ext)
                checkpoint_id = struct.unpack(">Q", val)
                checkpointStartExists = False
                if vbucketId in vbmap:
                    if vbmap[vbucketId][0] == checkpoint_id[0]:
                        # Repeated start for the checkpoint already open.
                        checkpointStartExists = True
                    else:
                        sys.exit("ERROR: CHECKPOINT_START with checkpoint Id %s arrived" \
                                 " before receiving CHECKPOINT_END with checkpoint Id %s"
                                 % (checkpoint_id[0], vbmap[vbucketId][0]))
                if not checkpointStartExists:
                    vbmap[vbucketId] = [checkpoint_id[0], 0]
                    t = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
                    s = "INSERT into cpoint_state" \
                        "(vbucket_id, cpoint_id, prev_cpoint_id, state, source, updated_at)" \
                        " VALUES (?, ?, ?, \"open\", ?, ?)"
                    c.execute(s, (vbucketId, checkpoint_id[0], -1, source, t))
                    db.commit()

            elif cmd == memcacheConstants.CMD_TAP_CHECKPOINT_END:
                db.commit()
                update_count = 0
                checkpoint_id = struct.unpack(">Q", val)
                if vbucketId not in vbmap:
                    sys.exit("ERROR: unmatched checkpoint end: %s vb: %s"
                             % (checkpoint_id[0], vbucketId))

                current_checkpoint_id, seq = vbmap[vbucketId]
                if current_checkpoint_id != checkpoint_id[0]:
                    sys.exit("ERROR: unmatched checkpoint end id: %s vb: %s cp: %s"
                             % (checkpoint_id[0], vbucketId, current_checkpoint_id))

                if len(ext) > 0:
                    eng_length, flags, ttl, flg, exp, needAck = parseTapExt(ext)

                del vbmap[vbucketId]

                t = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
                s = "UPDATE cpoint_state" \
                    " SET state=\"closed\", updated_at=?" \
                    " WHERE vbucket_id=? AND cpoint_id=? AND state=\"open\" AND source=?"
                r = c.execute(s, (t, vbucketId, current_checkpoint_id, source))
                db.commit()
                if r.rowcount != 1:
                    # Report the values used in the UPDATE; convert to str
                    # because the ids are integers.
                    sys.exit("ERROR: unexpected rowcount during update: "
                             + ",".join(str(x) for x in
                                        (t, vbucketId, current_checkpoint_id, source)))

            elif cmd == memcacheConstants.CMD_TAP_OPAQUE:
                if len(ext) > 0:
                    eng_length, flags, ttl, flg, exp, needAck = parseTapExt(ext)
                    opaque_opcode = struct.unpack(">I" , val[0:eng_length])
                    if opaque_opcode[0] == memcacheConstants.TAP_OPAQUE_OPEN_CHECKPOINT:
                        if update_count > 0:
                            db.commit()
                        log(1, "Incremental backup is currently at the open checkpoint. Exit...")
                        sys.exit(0)

            elif cmd == memcacheConstants.CMD_TAP_CONNECT:
                if update_count > 0:
                    db.commit()
                sys.exit("ERROR: TAP_CONNECT error: " + str(key))

            elif cmd == memcacheConstants.CMD_NOOP:
                pass

            else:
                sys.exit("ERROR: unhandled cmd " + str(cmd))

            if update_count == txn_size:
                db.commit()
                update_count = 0

            if needAck:
                mc._sendMsg(cmd, '', '', opaque,
                            vbucketId=0,
                            fmt=memcacheConstants.RES_PKT_FMT,
                            magic=memcacheConstants.RES_MAGIC_BYTE)

    except exceptions.EOFError:
        # Deliberate: EOF while reading ends the loop without an error.
        pass

def readTap(mc):
    """Read one message from mc and split its payload.

    The payload returned by mc._recvMsg() is laid out as extras, then key,
    then value; extlen and keylen give the sizes of the first two parts.
    Returns (cmd, opaque, cas, vbucketId, key, ext, val); key, ext and val
    are empty strings when the message has no payload.
    """
    cmd, vbucketId, opaque, cas, keylen, extlen, data = mc._recvMsg()
    if not data:
        return cmd, opaque, cas, vbucketId, '', '', ''

    key_end = extlen + keylen
    extras = data[:extlen]
    key = data[extlen:key_end]
    value = data[key_end:]
    return cmd, opaque, cas, vbucketId, key, extras, value

def encodeTAPConnectOpts(opts):
    """Encode TAP connect options into (header, body).

    opts maps flag bits to their values. The header is the bitwise OR of
    all flags packed as a big-endian 32-bit unsigned int. The body is the
    concatenation of the option values in ascending flag order; a value is
    packed with the struct format from memcacheConstants.TAP_FLAG_TYPES when
    the flag has one, and used as-is otherwise.
    """
    header = 0
    chunks = []
    for flag in sorted(opts):
        header |= flag
        payload = opts[flag]
        if flag in memcacheConstants.TAP_FLAG_TYPES:
            payload = struct.pack(memcacheConstants.TAP_FLAG_TYPES[flag], payload)
        chunks.append(payload)
    return struct.pack(">I", header), ''.join(chunks)

def parseTapExt(ext):
    """Unpack the extras section of a TAP message.

    An 8-byte ext is unpacked with TAP_GENERAL_PKT_FMT and has no item flags
    or expiry, so both are reported as 0. Any other length is unpacked with
    TAP_MUTATION_PKT_FMT.

    Returns (eng_length, flags, ttl, flg, exp, needAck), where needAck is
    flags masked with TAP_FLAG_ACK.
    """
    if len(ext) != 8:
        eng_length, flags, ttl, flg, exp = struct.unpack(
            memcacheConstants.TAP_MUTATION_PKT_FMT, ext)
    else:
        eng_length, flags, ttl = struct.unpack(
            memcacheConstants.TAP_GENERAL_PKT_FMT, ext)
        flg = exp = 0

    return eng_length, flags, ttl, flg, exp, flags & memcacheConstants.TAP_FLAG_ACK

def createSchema(db):
    """Create the backup tables in db and stamp it with MBB_VERSION.

    The database must still have user_version 0; any other value ends the
    process with an error message. The cpoint_op and cpoint_state tables and
    the new user_version are created in one script wrapped in BEGIN/COMMIT.
    """
    existing_version = db.execute("pragma user_version").fetchall()[0][0] # File's version.
    if existing_version != 0:
        sys.exit("ERROR: unexpected db user version: " + str(existing_version))

    schema_template = """
BEGIN;
CREATE TABLE cpoint_op
  (vbucket_id integer, cpoint_id integer, seq integer, op text,
   key varchar(250), flg integer, exp integer, cas integer, val blob);
CREATE TABLE cpoint_state
  (vbucket_id integer, cpoint_id integer, prev_cpoint_id integer, state varchar(1),
   source varchar(250), updated_at text);
pragma user_version=%s;
COMMIT;
"""
    db.executescript(schema_template % (MBB_VERSION))

# Run the backup only when this file is executed as a script, so that
# importing the module has no side effects beyond the definitions above.
if __name__ == '__main__':
    main()
