#!/usr/bin/env python

import sys
import os
import glob
import subprocess
import shutil

# Version number constant; it is not referenced by the code below.
# NOTE(review): presumably the user_version this tool upgrades to -- confirm.
TARGET_VERSION = 2

# Provide any() on interpreters whose builtins do not include it.
try:
    any
except NameError:
    def any(iterable):
        """Return True as soon as one element of iterable is truthy."""
        for item in iterable:
            if not item:
                continue
            return True
        return False

# Provide all() on interpreters whose builtins do not include it.
try:
    all
except NameError:
    def all(iterable):
        """Return False as soon as one element of iterable is falsy."""
        for item in iterable:
            if item:
                continue
            return False
        return True

def run_sql(sqlite, fn, sql, more_args=None, logger=sys.stderr):
    """Run SQL text against a database file through the sqlite3 shell.

    The shell is started as ``sqlite -batch -bail <more_args...> fn`` and
    ``sql`` is fed to it on standard input.

    sqlite    -- path of the sqlite3 command-line executable
    fn        -- path of the database file to operate on
    sql       -- the SQL text to execute
    more_args -- optional iterable of extra command-line arguments, placed
                 before the database filename (default: none)
    logger    -- object with a write() method; receives the failed query
                 and the shell's stderr output when the shell fails

    Returns whatever the shell wrote to standard output.  If the shell
    exits with a non-zero status, the failure is logged and SystemExit(1)
    is raised; callers in this file catch SystemExit to ignore failures.
    """
    # A None default avoids sharing one mutable list between calls, and
    # list() lets callers hand in a tuple or any other iterable.
    if more_args is None:
        extra = []
    else:
        extra = list(more_args)

    cmd = [sqlite, '-batch', '-bail'] + extra + [fn]
    p = subprocess.Popen(cmd,
                         stdin=subprocess.PIPE,
                         stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE)
    (o, e) = p.communicate(sql)
    if p.returncode != 0:
        logger.write("Error running query:  %s\n" % sql)
        logger.write(e)
        sys.exit(1)
    return o

def version(sqlite, fn):
    """Return the ``user_version`` pragma of database ``fn`` as an int."""
    raw = run_sql(sqlite, fn, "pragma user_version;")
    return int(raw.strip())

def count_kv_vb_tables(sqlite, fn):
    """Count sqlite_master entries in ``fn`` whose name matches ``kv_%``."""
    query = ('select count(*) from sqlite_master '
             'where name like "kv_%";')
    output = run_sql(sqlite, fn, query)
    return int(output.strip())

def findCmd(cmdName):
    """Return the path of the executable ``cmdName`` shipped with this script.

    Two directories are searched in order: the directory containing the
    running script (taken from sys.argv[0]) and ``../../bin`` relative to
    it.  In each directory the plain name is tried first, then the name
    with an ``.exe`` suffix.  The first existing path wins.

    Raises IndexError when no candidate exists.  The exception type matches
    what this function has always raised in that situation; it now carries
    a message naming the command and the directories that were searched.
    """
    cmd_dir = os.path.dirname(sys.argv[0])
    search_dirs = [cmd_dir, os.path.join(cmd_dir, "..", "..", "bin")]
    for bin_dir in search_dirs:
        for name in (cmdName, cmdName + '.exe'):
            candidate = os.path.join(bin_dir, name)
            if os.path.exists(candidate):
                return candidate
    raise IndexError("Could not find command %r (looked in %r)"
                     % (cmdName, search_dirs))

def findSqlite():
    """Locate the ``sqlite3`` executable with findCmd and return its path."""
    sqlite_path = findCmd('sqlite3')
    return sqlite_path

def usage(command, msg):
    """Exit the program with an error message followed by usage help.

    command -- program name shown in the usage line
    msg     -- description of what was wrong with the invocation

    Never returns: the formatted text is handed to sys.exit().
    """
    template = """Error:  %s

Usage:  %s /path/to/srcdb [...] /path/to/dest

src is the path to the main database.  You can specify more than one.
dest can be either a new main database filename, or a directory
(the old DB name will be used).

If more than one src db is specified, the destination *must* be a
directory."""
    sys.exit(template % (msg, command))

class NullLogger(object):
    """Logger stand-in whose write() discards its argument.

    Used as the ``logger`` for run_sql calls whose failures should not
    produce any output.
    """

    def write(self, x):
        """Accept ``x`` and ignore it."""
        return None

def updateVBStatesTableSchema(sqlite, src):
    """Add the vb_version and checkpoint_id columns to vbucket_states.

    Each ALTER TABLE runs on its own with output suppressed; a failure of
    one statement does not prevent the next from being attempted.
    """
    statements = (
        'alter table vbucket_states add column vb_version;',
        'alter table vbucket_states add column checkpoint_id;',
    )
    for statement in statements:
        try:
            run_sql(sqlite, src, statement, logger=NullLogger())
        except SystemExit:
            # run_sql exits on any failure; that is taken here to mean
            # the column was already present.
            pass

def updateKVTableSchema(sqlite, src):
    """Add the vb_version column to the kv table of every data file.

    Data files are the paths matching ``src*.mb`` and ``src*.sqlite``.
    """
    data_files = glob.glob(src + '*.mb') + glob.glob(src + '*.sqlite')
    for data_file in data_files:
        if not os.path.exists(data_file):
            continue
        try:
            run_sql(sqlite, data_file, 'alter table kv add column vb_version;',
                    logger=NullLogger())
        except SystemExit:
            # run_sql exits on any failure; that is taken here to mean
            # the column was already present.
            pass

def ensureNewColumns(sqlite, src):
    """Run both schema updates: vbucket_states first, then the kv tables."""
    for upgrade in (updateVBStatesTableSchema, updateKVTableSchema):
        upgrade(sqlite, src)

def findTarget(src, dest):
    """Return the destination path for ``src``.

    When ``dest`` is an existing directory the result is the base name of
    ``src`` inside that directory; otherwise ``dest`` is returned as is.
    """
    if not os.path.isdir(dest):
        return dest
    return os.path.join(dest, os.path.basename(src))

def findTargetCollisions(srcs, dest):
    """List existing files that converting ``srcs`` into ``dest`` would hit.

    For each source, its target path is included if it exists, followed
    by every existing path matching ``<target>*.mb``.
    """
    collisions = []
    for src in srcs:
        target = findTarget(src, dest)
        if os.path.exists(target):
            collisions.append(target)
        collisions.extend([shard for shard in glob.glob(target + '*.mb')
                           if os.path.exists(shard)])
    return collisions

def doset(sqlite, src, dest):
    """Convert one source database into ``dest`` by running mbdbconvert.

    sqlite -- path of the sqlite3 executable used for the queries
    src    -- path of the source main database
    dest   -- destination file or directory (resolved with findTarget)

    The mbdbconvert command line is printed before it is run.  If the
    converter exits with a non-zero status, an error is written to stderr
    and the program exits with status 1.
    """
    src_version = version(sqlite, src)
    print('Source version from "%s" is %d' % (src, src_version))

    convert_cmd = [findCmd('mbdbconvert')]
    if src_version < 2:
        convert_cmd.append("--remove-crlf")

    # The presence of "<src>-0.mb" selects the sharded source pattern.
    first_shard = src + '-0.mb'
    if os.path.exists(first_shard):
        convert_cmd.append('--src-pattern=%d/%b-%i.mb')
        if count_kv_vb_tables(sqlite, first_shard) > 0:
            convert_cmd.append("--src-strategy=multiMTVBDB")

    convert_cmd.extend([src, findTarget(src, dest), "--report-every=2000"])

    # init.sql is one directory level closer to the script on Windows.
    script_dir = os.path.dirname(sys.argv[0])
    if os.name == 'nt':
        init_file = os.path.join(script_dir, '..', 'etc', 'membase', 'init.sql')
    else:
        init_file = os.path.join(script_dir, '..', '..', 'etc', 'membase', 'init.sql')
    convert_cmd.append("--init-file=" + init_file)

    print(" ".join(convert_cmd))

    # Add the new columns to the source files before running the converter.
    ensureNewColumns(sqlite, src)

    status = subprocess.call(convert_cmd)
    if status == 0:
        return
    sys.stderr.write("Error running convert, exit code %d (%s)\n" %
                     (status, ' '.join(convert_cmd)))
    sys.exit(1)

def main():
    """Command-line entry point.

    The last argument is the destination and every argument before it is
    a source database.  The arguments are validated, the run is refused if
    any target file already exists, and then each source is converted in
    turn.
    """
    command = sys.argv[0]
    cli_args = sys.argv[1:]

    if not cli_args:
        usage(command, "Too few arguments")
    dest = cli_args[-1]
    srcs = cli_args[:-1]

    if len(srcs) == 0:
        usage(command, "No destination specified")
    for path in srcs:
        if not os.path.isfile(path):
            usage(command, "Source must point to existing main DB files")
    if len(srcs) > 1:
        if not os.path.isdir(dest):
            usage(command,
                  "Multiple sources, but destination is not a directory.")

    if os.name == 'nt':
        # Prepend the script, memcached and erlang directories to PATH.
        mydir = os.path.dirname(command)
        search_path = [
            mydir,
            os.path.join(mydir, '..', '..', 'memcached'),
            os.path.join(mydir, 'erlang', 'bin'),
            os.environ['PATH'],
        ]
        os.environ['PATH'] = ';'.join(search_path)

    sqlite = findSqlite()

    # Refuse to run at all rather than overwrite an existing file.
    collisions = findTargetCollisions(srcs, dest)
    if collisions:
        sys.stderr.write("Would overwrite the following files:\n\t")
        sys.stderr.write('\n\t'.join(collisions) + '\n')
        sys.exit(1)

    for src in srcs:
        doset(sqlite, src, dest)

# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()
