#!/usr/bin/env python
# -*-python-*-

"""
Restore tool for Membase.

"""

import itertools
import optparse
import Queue
import socket
import sys
import os
import subprocess
try:
    import sqlite3
except:
    print "mbrestore requires python version 2.6 or greater"
    sys.exit(1)
import thread
import threading
import time
import traceback
import ctypes
import mc_bin_client

# Default for the --threads command-line option.
DEFAULT_THREADS = 4
# Defaults for the --host / --port command-line options.
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 11211
# Maximum number of rows buffered in the queue between the database reader
# in main() and the worker thread(s).
QUEUE_SIZE = 1000

def connect(opts):
    """Create a client connection to the server described by ``opts``.

    ``opts`` must provide ``host``, ``port``, ``username`` and ``password``
    attributes.  SASL PLAIN authentication is performed only when a username
    was supplied.

    Returns the connected ``mc_bin_client.MemcachedClient``.
    """
    client = mc_bin_client.MemcachedClient(opts.host, opts.port)
    if opts.username is None:
        # Anonymous connection: nothing more to do.
        return client
    client.sasl_auth_plain(opts.username, opts.password)
    return client

def findCmd(cmdName):
    """Locate the executable ``cmdName`` relative to this script.

    Two directories are searched, in order: the directory containing the
    running script (taken from ``sys.argv[0]``) and ``../../bin`` relative to
    it.  In each directory both ``cmdName`` and ``cmdName + '.exe'`` are
    tried.

    Returns the first candidate path that exists.

    Raises IndexError if no candidate exists.  The exception type is kept
    for backward compatibility; the message now names the command and the
    paths that were tried.
    """
    cmd_dir = os.path.dirname(sys.argv[0])
    search_dirs = [cmd_dir, os.path.join(cmd_dir, "..", "..", "bin")]
    candidates = [os.path.join(bin_dir, name)
                  for bin_dir in search_dirs
                  for name in (cmdName, cmdName + '.exe')]
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    raise IndexError("Unable to locate command %r; looked for: %s"
                     % (cmdName, ", ".join(candidates)))

def findSqlite():
    """Return the path of the ``sqlite3`` executable, as located by findCmd."""
    sqlite_path = findCmd('sqlite3')
    return sqlite_path

def run_sql(sqlite, fn, sql, more_args=None, logger=sys.stderr):
    """Run ``sql`` against database file ``fn`` using an external sqlite binary.

    sqlite    -- path of the sqlite command-line executable.
    fn        -- database filename, passed as the last command-line argument.
    sql       -- SQL text fed to the process on its standard input.
    more_args -- optional iterable of extra command-line arguments, inserted
                 after the fixed ``-batch -bail`` flags; ``None`` means none.
    logger    -- object with a ``write`` method that receives the failing
                 query and the process's stderr output on error.

    Returns whatever the process wrote to standard output.

    If the process exits with a non-zero status, the error is written to
    ``logger`` and the program terminates via ``sys.exit(1)``.
    """
    # A None default (instead of a shared list literal) avoids the
    # mutable-default-argument pitfall; list() also accepts tuples.
    extra_args = list(more_args or [])
    cmd = [sqlite, '-batch', '-bail'] + extra_args + [fn]
    p = subprocess.Popen(cmd,
                         stdin=subprocess.PIPE,
                         stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE)
    (o, e) = p.communicate(sql)
    if p.returncode != 0:
        logger.write("Error running query:  %s\n" % sql)
        logger.write(e)
        sys.exit(1)
    return o

def worker(queue, opts):
    """Thread body: take rows from ``queue`` and store them on the server.

    queue -- queue of ``(k, flags, exptime, v)`` tuples; ``task_done()`` is
             called once per row that has been stored or skipped.
    opts  -- parsed command-line options; passed to connect(), and
             ``opts.add`` selects ``add`` (truthy) or ``set`` (falsy) as the
             store operation.

    Error handling per row:
      * MemcachedError with status 2 ("already exists"): the row is
        considered done and the next one is fetched.
      * Any other MemcachedError: the error is reported on stderr and the
        same row is retried after a one second pause.
      * Any other exception: the traceback is printed, a new connection is
        opened and the same row is retried.

    The loop never ends normally.  If an exception escapes it (for example
    because reconnecting fails), the traceback is printed and the main
    thread is interrupted via ``thread.interrupt_main()``.
    """
    def open_updater():
        # Return the store operation bound to a *fresh* connection.  The
        # bound method must be re-created after every reconnect; otherwise
        # retries would keep using the old, broken connection.
        mc = connect(opts)
        if opts.add:
            return mc.add
        return mc.set

    try:
        try:
            update = open_updater()
            k, flags, exptime, v = queue.get()
            while True:
                try:
                    # Reinterpret the stored flags as an unsigned 32-bit
                    # value before the byte-order conversion.
                    flags = ctypes.c_uint32(flags).value
                    update(str(k), exptime, socket.ntohl(flags), bytearray(v))
                except mc_bin_client.MemcachedError as error:
                    if error.status == 2:
                        # Already exists
                        queue.task_done()
                        k, flags, exptime, v = queue.get()
                    else:
                        sys.stderr.write(
                            "Error on key {0!r}: {1}\n".format(str(k), error))
                        time.sleep(1)
                except Exception:
                    # For other errors, reconnect and retry the same row
                    traceback.print_exc()
                    update = open_updater()
                else:
                    queue.task_done()
                    k, flags, exptime, v = queue.get()
        except:
            # Thread boundary: deliberately catch everything so the failure
            # is logged before the main thread is interrupted.
            traceback.print_exc()
    finally:
        thread.interrupt_main()


def db_file_versions(sqlite, db_filenames):
    """Map every filename in ``db_filenames`` to its version number.

    ``sqlite`` is the path of the sqlite executable; each file's version is
    obtained by calling version() on it.
    """
    return dict((path, version(sqlite, path)) for path in db_filenames)

def version(sqlite, fn):
    """Return the ``PRAGMA user_version`` value of database file ``fn`` as an int.

    The query is executed through run_sql() using the ``sqlite`` executable.
    """
    raw_output = run_sql(sqlite, fn, "PRAGMA user_version;")
    return int(raw_output.strip())

def main():
    """Command-line entry point: restore keys from sqlite backup files.

    Steps:
      1. Parse options and validate that every given db file exists.
      2. Check that at least one file reports ``user_version`` >= 2.
      3. Switch every file's journal mode to DELETE, then attach all files
         to an in-memory sqlite database.
      4. Find the tables whose name contains ``kv_`` and the single database
         holding the ``vbucket_states`` table.
      5. Start ``--threads`` worker threads and feed them every row of the
         ``kv_`` tables that belongs to a vbucket in the 'active' state.

    Exits through ``sys.exit`` / ``parser.error`` with a message when the
    arguments or database files are unusable.
    """
    usage = "%prog [opts] db_files (use -h for detailed help)"
    epilog = "Restore keys from the sqlite backing store files from a single node."
    parser = optparse.OptionParser(usage=usage, epilog=epilog)
    parser.add_option("-a", "--add", action="store_true", default=False,
                      help="Use add instead of set to avoid overwriting existing items")
    parser.add_option("-H", "--host", default=DEFAULT_HOST,
                      help="Hostname of moxi server to connect to")
    parser.add_option("-p", "--port", type="int", default=DEFAULT_PORT,
                      help="Port of moxi server to connect to")
    parser.add_option("-u", "--username", help="Bucket username (usually the bucket name) to authenticate to moxi with")
    parser.add_option("-P", "--password", default="",
                      help="Bucket password to authenticate to moxi with")
    parser.add_option("-t", "--threads", type="int", default=DEFAULT_THREADS,
                      help="Number of worker threads")
    opts, db_filenames = parser.parse_args()

    if len(db_filenames) == 0:
        parser.print_usage()
        sys.exit("No filenames provided")
    if opts.threads < 1:
        # With no workers the bounded queue would fill up and block forever.
        parser.error("--threads must be at least 1")
    for fn in db_filenames:
        if not os.path.isfile(fn):
            sys.exit("The file doesn't exist: " + fn)

    sqlite = findSqlite()

    versions = db_file_versions(sqlite, db_filenames)
    if max(versions.values()) < 2:
        sys.exit("Either the metadata db file is not specified\n"
                 "or the backup files need to be upgraded to version 2.\n"
                 "Please use mbdbupgrade for upgrade or contact support.")

    # One schema alias (db0, db1, ...) per input file, in argument order.
    attached_dbs = ["db{0}".format(i) for i in xrange(len(db_filenames))]

    # Backwards compatibility: Deliberately changing out of WAL mode
    # so that older versions of SQLite can access the database files
    for db_name in db_filenames:
        run_sql(sqlite, db_name, "PRAGMA journal_mode=DELETE;")

    # Use an in-memory database as the main database ...
    db = sqlite3.connect(':memory:')
    # ... and attach every given file to it under its alias
    db.executemany("attach ? as ?", zip(db_filenames, attached_dbs))
    # Find all the tables: table name -> list of aliases that contain it
    table_dbs = {}
    cur = db.cursor()
    for db_name in attached_dbs:
        cur.execute("select name from %s.sqlite_master where type = 'table'" % db_name)
        for (table_name,) in cur:
            table_dbs.setdefault(table_name, []).append(db_name)

    if not any('kv_' in table_name for table_name in table_dbs):
        sys.exit("No data to be restored. Check if db files are correct.")

    # Determine which db the state table is in; exits if the table is
    # missing or present in more than one database
    try:
        (state_db,) = table_dbs.get(u'vbucket_states', [])
    except ValueError:
        sys.exit("Unable to locate unique vbucket_states table in database files")

    # The doubled braces survive this first format() call as {0}/{1}
    # placeholders, which are filled in per table with (db alias, table name).
    # NOTE(review): "kv.vb_version = kv.vb_version" compares a column with
    # itself; it looks like "= vb.vb_version" may have been intended.  Left
    # unchanged because the vbucket_states schema is not visible here.
    sql = """select k, flags, exptime, v from `{{0}}`.`{{1}}` as kv,
             `{0}`.vbucket_states as vb
             where kv.vbucket = vb.vbid and kv.vb_version = kv.vb_version and
             vb.state like 'active'""".format(state_db)

    queue = Queue.Queue(QUEUE_SIZE)
    # Honour --threads: start that many daemon workers sharing one queue.
    workers = [threading.Thread(target=worker, args=(queue, opts))
               for _ in xrange(opts.threads)]
    for worker_thread in workers:
        worker_thread.daemon = True
        worker_thread.start()

    count = 0
    for kv, dbs in table_dbs.iteritems():
        if 'kv_' in kv:
            for db_name in dbs:
                cur.execute(sql.format(db_name, kv))
                for row in cur:
                    queue.put(row)
                    count += 1
                    # Progress report every 1024 queued rows
                    if count & 1023 == 0:
                        print(count)
    # Wait until the workers have acknowledged every queued row
    queue.join()

# Run the restore only when this file is executed as a script.
if __name__ == '__main__':
    main()
