#!/usr/bin/env python

import clitool
import sys
import math
import itertools
import mc_bin_client
import re

# Splits a string around runs of decimal digits.  The capturing group makes
# re.split() keep the digit runs in the result list.  Written as a raw string
# so the backslash reaches the regex engine untouched.
MAGIC_CONVERT_RE = re.compile(r"(\d+)")

# Bounds used by time_label: values above BIG_VALUE are reported as 'inf',
# values below SMALL_VALUE as '-inf'.
BIG_VALUE = 2 ** 60
SMALL_VALUE = - (2 ** 60)

def cmd(f):
    """Decorate a function with code to authenticate based on 1-2
    arguments past the normal number of arguments."""

    def g(*args, **kwargs):
        mc = args[0]
        n = f.func_code.co_argcount
        if len(args) > n:
            print "Too many args, given %s, but expected a maximum of %s"\
                    % (list(args[1:]), n - 1)
            sys.exit(1)

        bucket = kwargs.get('bucketName', None)
        password = kwargs.get('password', None) or ""

        if bucket:
            try:
                mc.sasl_auth_plain(bucket, password)
            except mc_bin_client.MemcachedError:
                print "Authentication error for %s" % bucket
                sys.exit(1)

        if kwargs.get('allBuckets', None):
            buckets = mc.stats('bucket')
            for bucket in buckets.iterkeys():
                print '*' * 78
                print bucket
                print
                mc.bucket_select(bucket)
                f(*args[:n])
        else:
            f(*args[:n])

    return g

def stats_perform(mc, cmd=''):
    try:
        return mc.stats(cmd)
    except:
        print "Stats '%s' are not available from the requested engine." % cmd

def stats_formatter(stats, prefix=" ", cmp=None):
    if stats:
        longest = max((len(x) + 2) for x in stats.keys())
        for stat, val in sorted(stats.items(), cmp=cmp):
            s = stat + ":"
            print "%s%s%s" % (prefix, s.ljust(longest), val)

def time_label(s):
    """Return a short human-readable label for a duration in microseconds.

    Examples:
        -(2**64) -> '-inf'   (anything below SMALL_VALUE)
        2**64    -> 'inf'    (anything above BIG_VALUE)
        0        -> '0'
        4        -> '4us'
        838384   -> '838ms'
        8283852  -> '8s'
        -1500    -> '-1ms'

    The largest unit whose size does not exceed the magnitude of `s` is
    used and the quotient is truncated toward zero.

    NOTE: time_label is defined a second time further down in this file;
    that later definition replaces this one when the module is loaded.
    """
    if s > BIG_VALUE:
        return 'inf'
    elif s < SMALL_VALUE:
        return '-inf'
    elif s == 0:
        return '0'

    # (label, size in microseconds), largest unit first.
    units = (('m', 60 * 1000 * 1000),
             ('s', 1000 * 1000),
             ('ms', 1000),
             ('us', 1))

    # Pick the unit by magnitude so negative values get a unit too; the
    # loop falls through to ('us', 1) when the magnitude is below 1.
    magnitude = abs(s)
    for lbl, factor in units:
        if factor <= magnitude:
            break

    amount = int(magnitude / factor)
    if s < 0:
        amount = -amount
    return "%d%s" % (amount, lbl)

def sec_label(s):
    """Label a duration given in seconds by handing it to time_label in
    microseconds."""
    usec_per_sec = 1000000
    return time_label(s * usec_per_sec)

def size_label(s):
    """Return a short human-readable label for a byte count.

    The value is divided by the largest power of 1024 that does not exceed
    its magnitude and truncated toward zero; plain byte counts get no
    suffix.  Examples: 0 -> '0', 512 -> '512', 1536 -> '1KB',
    -2048 -> '-2KB'.

    Magnitudes of 1024**7 and above are expressed in 'EB' (the largest
    suffix available) instead of running off the end of the suffix table,
    and magnitudes below 1 are treated as plain bytes.
    """
    if s == 0:
        return "0"
    sizes = ['', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB']

    # Find the exponent with exact integer comparisons (no floating-point
    # logarithm) and clamp it to the suffixes that exist.
    magnitude = abs(s)
    exponent = 0
    while exponent < len(sizes) - 1 and magnitude >= 1024 ** (exponent + 1):
        exponent += 1

    # Float division followed by %d truncates toward zero, for negative
    # values as well.
    return "%d%s" % (s / float(1024 ** exponent), sizes[exponent])

@cmd

def histograms(mc, raw_stats):
    # Try to figure out the terminal width.  If we can't, 79 is good
    def termWidth():
        try:
            import fcntl, termios, struct
            h, w, hp, wp = struct.unpack('HHHH',
                                         fcntl.ioctl(0, termios.TIOCGWINSZ,
                                                     struct.pack('HHHH', 0, 0, 0, 0)))
            return w
        except:
            return 79

    special_labels = {'klogPadding': size_label,
                      'item_alloc_sizes': size_label,
                      'paged_out_time': sec_label}

    histodata = {}
    for k, v in raw_stats.items():
        # Parse out a data point
        ka = k.split('_')
        k = '_'.join(ka[0:-1])
        kstart, kend = [int(x) for x in ka[-1].split(',')]

        # Create a label for the data point
        label_func = time_label
        if k.endswith("Size") or k.endswith("Seek"):
            label_func = size_label
        elif k in special_labels:
            label_func = special_labels[k]

        label = "%s - %s" % (label_func(kstart), label_func(kend))
        
        if not k in histodata:
            histodata[k] = []
        histodata[k].append({'start' : int(kstart),
                             'end'   : int(kend),
                             'label' : label,
                             'lb_fun': label_func,
                             'value' : int(v)})
    
    for name, data_points in histodata.items():
        max_label_len = max([len(stat['label']) for stat in data_points])
        widestval = len(str(max([stat['value'] for stat in data_points])))
        total = sum([stat['value'] for stat in data_points])
        avg = sum([((x['end'] - x['start']) * x['value']) for x in data_points]) / total

        print " %s (%d total)" % (name, total)

        total_seen = 0
        for dp in sorted(data_points, key=lambda x: x['start']):
            total_seen += dp['value']
            pcnt = (total_seen * 100.0) / total
            toprint  = "    %s : (%6.02f%%) %s" % \
                       (dp['label'].ljust(max_label_len), pcnt,
                       str(dp['value']).rjust(widestval))

            remaining = termWidth() - len(toprint) - 2
            lpcnt = float(dp['value']) / total
            print "%s %s" % (toprint, '#' * int(lpcnt * remaining))
        print "    %s : (%s)" % ("Avg".ljust(max_label_len),
                                dp['lb_fun'](avg).rjust(7))

@cmd
def stats_key(mc, key, vb):
    cmd = "key %s %s" % (key, str(vb))
    vbs = stats_perform(mc, cmd)
    if vbs:
        print "stats for key", key
        stats_formatter(vbs)

@cmd
def stats_vkey(mc, key, vb):
    cmd = "vkey %s %s" % (key, str(vb))
    try:
        vbs = mc.stats(cmd)
    except mc_bin_client.MemcachedError, e:
        print e.message
        sys.exit()
    except Exception, e:
        print "Stats '%s' are not available from the requested engine." % cmd
        sys.exit()

    if vbs:
        print "verification for key", key
        stats_formatter(vbs)

@cmd
def stats_all(mc):
    """Fetch the default stats group (empty sub-command) and print it."""
    everything = stats_perform(mc)
    stats_formatter(everything)

def time_label(s):
    """Return a short human-readable label for a duration in microseconds.

    Examples:
        -(2**64) -> '-inf'   (anything below SMALL_VALUE)
        2**64    -> 'inf'    (anything above BIG_VALUE)
        0        -> '0'
        4        -> '4us'
        838384   -> '838ms'
        8283852  -> '8s'
        -1500    -> '-1ms'

    The largest unit whose size does not exceed the magnitude of `s` is
    used and the quotient is truncated toward zero.

    NOTE: this redefines the time_label function that appears earlier in
    this file; being the later definition, this is the one in effect once
    the module has loaded.
    """
    if s > BIG_VALUE:
        return 'inf'
    elif s < SMALL_VALUE:
        return '-inf'
    elif s == 0:
        return '0'

    # (label, size in microseconds), largest unit first.
    units = (('m', 60 * 1000 * 1000),
             ('s', 1000 * 1000),
             ('ms', 1000),
             ('us', 1))

    # Pick the unit by magnitude so negative values get a unit too; the
    # loop falls through to ('us', 1) when the magnitude is below 1.
    magnitude = abs(s)
    for lbl, factor in units:
        if factor <= magnitude:
            break

    amount = int(magnitude / factor)
    if s < 0:
        amount = -amount
    return "%d%s" % (amount, lbl)

@cmd
def stats_timings(mc):
    """Fetch the 'timings' stats group and print it as histograms."""
    timing_stats = stats_perform(mc, 'timings')
    if not timing_stats:
        return
    histograms(mc, timing_stats)

@cmd
def stats_tap(mc):
    """Fetch the 'tap' stats group and print it."""
    tap_stats = stats_perform(mc, 'tap')
    stats_formatter(tap_stats)

@cmd
def stats_tapagg(mc):
    """Fetch the 'tapagg _' stats group and print it."""
    aggregated = stats_perform(mc, 'tapagg _')
    stats_formatter(aggregated)

@cmd
def stats_checkpoint(mc, vb=-1):
    """Fetch and print checkpoint stats.

    With the default vb of -1 the plain 'checkpoint' group is requested;
    otherwise the request is "checkpoint <vb>".
    """
    request = 'checkpoint'
    if vb != -1:
        request = "checkpoint %s" % (str(vb))
    checkpoint_stats = stats_perform(mc, request)
    stats_formatter(checkpoint_stats)

@cmd
def stats_allocator(mc):
    print stats_perform(mc, 'allocator')['detailed']

@cmd
def stats_slabs(mc):
    """Fetch the 'slabs' stats group and print it."""
    slab_stats = stats_perform(mc, 'slabs')
    stats_formatter(slab_stats)

@cmd
def stats_items(mc):
    """Fetch the 'items' stats group and print it."""
    item_stats = stats_perform(mc, 'items')
    stats_formatter(item_stats)

@cmd
def stats_vbucket(mc):
    """Fetch the 'vbucket' stats group and print it."""
    vbucket_stats = stats_perform(mc, 'vbucket')
    stats_formatter(vbucket_stats)

@cmd
def stats_vbucket_details(mc):
    """Fetch the 'vbucket-details' stats group and print it."""
    details = stats_perform(mc, 'vbucket-details')
    stats_formatter(details)

@cmd
def stats_prev_vbucket(mc):
    """Fetch the 'prev-vbucket' stats group and print it."""
    previous = stats_perform(mc, 'prev-vbucket')
    stats_formatter(previous)

@cmd
def stats_memory(mc):
    """Fetch the 'memory' stats group and print it."""
    memory_stats = stats_perform(mc, 'memory')
    stats_formatter(memory_stats)

@cmd
def stats_config(mc):
    """Fetch the 'config' stats group and print it."""
    config_stats = stats_perform(mc, 'config')
    stats_formatter(config_stats)

@cmd
def stats_warmup(mc):
    """Fetch the 'warmup' stats group and print it."""
    warmup_stats = stats_perform(mc, 'warmup')
    stats_formatter(warmup_stats)

@cmd
def stats_klog(mc):
    """Fetch the 'klog' stats group and print it."""
    klog_stats = stats_perform(mc, 'klog')
    stats_formatter(klog_stats)

@cmd
def stats_info(mc):
    """Fetch the 'info' stats group and print it."""
    info_stats = stats_perform(mc, 'info')
    stats_formatter(info_stats)

@cmd
def stats_raw(mc, arg):
    """Fetch the stats group named by `arg`, passed through unchanged, and
    print it."""
    raw_stats = stats_perform(mc, arg)
    stats_formatter(raw_stats)

@cmd
def stats_kvstore(mc):
    """Fetch the 'kvstore' stats group and print it."""
    kvstore_stats = stats_perform(mc, 'kvstore')
    stats_formatter(kvstore_stats)

@cmd
def stats_kvtimings(mc):
    """Fetch the 'kvtimings' stats group and print it as histograms."""
    kv_timing_stats = stats_perform(mc, 'kvtimings')
    if not kv_timing_stats:
        return
    histograms(mc, kv_timing_stats)

def avg(s):
    """Return sum(s) divided by len(s).

    Uses the plain `/` operator, so the result follows the interpreter's
    division rules for the element type.  An empty sequence raises
    ZeroDivisionError.
    """
    total = sum(s)
    count = len(s)
    return total / count

def _maybeInt(x):
    try:
        return int(x)
    except:
        return x

def _magicConvert(s):
    """Split `s` around its digit runs (using MAGIC_CONVERT_RE) and return
    the pieces as a list, with each piece passed through _maybeInt."""
    converted = []
    for piece in MAGIC_CONVERT_RE.split(s):
        converted.append(_maybeInt(piece))
    return converted

def magic_cmp(a, b):
    """cmp-style comparison of two sequences by their first element.

    The first element of each argument is converted with _magicConvert and
    the two resulting lists are compared with the built-in cmp().
    """
    left = _magicConvert(a[0])
    right = _magicConvert(b[0])
    return cmp(left, right)

@cmd
def stats_hash(mc, with_detail=None):
    """Fetch the 'hash' stats group, add summary entries and print it.

    Values of keys containing 'max_dep', 'min_dep' and 'counted' are
    collected and summarised (averages, extremes, total).  Values of keys
    containing ':histo' are summed into 'summary:<part after the colon>'
    entries.  Unless with_detail is the string 'detail', entries whose key
    contains 'vb_' are left out of the output.

    A summary entry is only added when at least one matching value was
    found, so a reply without such keys no longer raises.
    """
    h = stats_perform(mc,'hash')
    if not h:
        return
    with_detail = with_detail == 'detail'

    mins = []
    maxes = []
    counts = []
    # Iterate over a snapshot: the loop adds 'summary:' keys to h.
    for k, v in list(h.items()):
        if 'max_dep' in k:
            maxes.append(int(v))
        if 'min_dep' in k:
            mins.append(int(v))
        if 'counted' in k:
            counts.append(int(v))
        if ':histo' in k:
            vb, kbucket = k.split(':')
            skey = 'summary:' + kbucket
            h[skey] = int(v) + h.get(skey, 0)

    if mins:
        h['avg_min'] = avg(mins)
        h['largest_min'] = max(mins)
    if maxes:
        h['avg_max'] = avg(maxes)
        h['largest_max'] = max(maxes)
    if counts:
        h['avg_count'] = avg(counts)
        h['min_count'] = min(counts)
        h['max_count'] = max(counts)
        h['total_counts'] = sum(counts)

    toDisplay = h
    if not with_detail:
        toDisplay = {}
        for k in h:
            if 'vb_' not in k:
                toDisplay[k] = h[k]

    stats_formatter(toDisplay, cmp=magic_cmp)

@cmd
def stats_dispatcher(mc, with_logs='no'):
    with_logs = with_logs == 'logs'
    sorig = stats_perform(mc,'dispatcher')
    if not sorig:
        return
    s = {}
    logs = {}
    slowlogs = {}
    for k,v in sorig.items():
        ak = tuple(k.split(':'))
        if ak[-1] == 'runtime':
            v = time_label(int(v))

        dispatcher = ak[0]

        for h in [logs, slowlogs]:
            if dispatcher not in h:
                h[dispatcher] = {}

        if ak[0] not in s:
            s[dispatcher] = {}

        if ak[1] in ['log', 'slow']:
            offset = int(ak[2])
            field = ak[3]
            h = {'log': logs, 'slow': slowlogs}[ak[1]]
            if offset not in h[dispatcher]:
                h[dispatcher][offset] = {}
            h[dispatcher][offset][field] = v
        else:
            field = ak[1]
            s[dispatcher][field] = v

    for dispatcher in sorted(s):
        print " %s" % dispatcher
        stats_formatter(s[dispatcher], "     ")
        for l,h in [('Slow jobs', slowlogs), ('Recent jobs', logs)]:
            if with_logs and h[dispatcher]:
                print "     %s:" % l
                for offset, fields in sorted(h[dispatcher].items()):
                    stats_formatter(fields, "        ")
                    print "        ---------"

@cmd
def reset(mc):
    """Issue the 'reset' stats command; whatever it returns is discarded."""
    reset_command = 'reset'
    stats_perform(mc, reset_command)

def main():
    """Register every stats sub-command and the -a/-b/-p options with a
    clitool.CliTool instance, then call its execute() method."""
    c = clitool.CliTool()

    # (command name, handler, usage text), registered in this order.
    commands = (
        ('all', stats_all, 'all'),
        ('allocator', stats_allocator, 'allocator'),
        ('checkpoint', stats_checkpoint, 'checkpoint [vbid]'),
        ('config', stats_config, 'config'),
        ('dispatcher', stats_dispatcher, 'dispatcher [logs]'),
        ('hash', stats_hash, 'hash [detail]'),
        ('items', stats_items, 'items (memcached bucket only)'),
        ('key', stats_key, 'key keyname vbid'),
        ('klog', stats_klog, 'klog'),
        ('kvstore', stats_kvstore, 'kvstore'),
        ('kvtimings', stats_kvtimings, 'kvtimings'),
        ('memory', stats_memory, 'memory'),
        ('prev-vbucket', stats_prev_vbucket, 'prev-vbucket'),
        ('raw', stats_raw, 'raw argument'),
        ('reset', reset, 'reset'),
        ('slabs', stats_slabs, 'slabs (memcached bucket only)'),
        ('tap', stats_tap, 'tap'),
        ('tapagg', stats_tapagg, 'tapagg'),
        ('timings', stats_timings, 'timings'),
        ('vbucket', stats_vbucket, 'vbucket'),
        ('vbucket-details', stats_vbucket_details, 'vbucket-details'),
        ('vkey', stats_vkey, 'vkey keyname vbid'),
        ('warmup', stats_warmup, 'warmup'),
    )
    for name, handler, usage in commands:
        c.addCommand(name, handler, usage)

    c.addFlag('-a', 'allBuckets', 'iterate over all buckets (requires admin u/p)')
    c.addOption('-b', 'bucketName', 'the bucket to get stats from (Default: default)')
    c.addOption('-p', 'password', 'the password for the bucket if one exists')

    c.execute()

# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
