epcim
6/16/2016 - 3:48 PM

Cluster execution tool

Cluster execution tool

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Tool for executing commands on clusters of hosts over SSH
"""

import os, sys, logging, urllib
import argparse
import threading
import paramiko
import time
import socket
import base64
import signal
import getpass

# Default log level is WARN; main() lowers it to DEBUG when --debug is given.
logging.basicConfig(level=logging.WARN, format='%(levelname)s: %(message)s')
# Root logger used by every function in this file.
lg = logging.getLogger()

# State shared between main() and the worker threads:
# sshPool   - connect address -> paramiko.SSHClient, filled by runRemote()
#             and closed by sshCleanup()
# sshFailed - connect addresses for which runRemote() could not set up a session
# sshDone   - hostnames whose runRemote() call has finished (success or failure)
# sshHosts  - hosts of the current run, assigned in main()
sshPool   = {}
sshFailed = []
sshDone   = []
sshHosts  = []

# Semaphore limiting how many workers run at once; created in main().
threadLimiter = None

# Non-zero exit status recorded by a worker; main() exits with it when
# --exitcode is given.
exitcode = 0

def _wait_for_pool(pool):
    """
        Poll until every thread in pool has finished.

        Returns True when all threads finished on their own and False when
        the wait was interrupted (KeyboardInterrupt / SystemExit); in that
        case the threads that are still alive are stopped on a best-effort
        basis before returning.
    """
    try:
        while True:
            alive = sum(1 for thread in pool if thread.is_alive())
            if not alive:
                return True
            lg.debug("Waiting for %i threads" % alive)
            time.sleep(0.5)
    except (KeyboardInterrupt, SystemExit):
        lg.debug("Received keyboard interrupt. Cleaning threads and exitting.")
        for thread in pool:
            if thread.is_alive():
                lg.debug("Killing thread %s" % thread.getName())
                try:
                    # Name-mangled private method of threading.Thread; it may
                    # not exist on every interpreter, hence the broad except.
                    thread._Thread__stop()
                except Exception as e:
                    lg.error("Thread %s cannot be terminated: %s" % (thread.getName(), e))
        return False


def main():
    """
        main entrance

        Parses the command line, builds the list of target hosts (from
        --file, stdin or --machines) and runs the command or upload on all
        of them, either once or repeatedly from an interactive prompt.

        Leaves through sys.exit() on usage errors (1), when connections
        failed (1), at the end of an interactive session (0, or 1 when a
        running command was interrupted) and, with --exitcode, with the
        exit status recorded by the workers.

        Raises RuntimeError when neither --file nor --machines is given.
    """

    global exitcode, threadLimiter, sshHosts

    # Catch SIGINFO if supported
    if hasattr(signal, 'SIGINFO'):
        signal.signal(signal.SIGINFO, siginfo_handler)
    if hasattr(signal, 'SIGUSR1'):
        signal.signal(signal.SIGUSR1, siginfo_handler)

    parser = argparse.ArgumentParser(description='Execute command on cluster', add_help=False)
    # Required
    group_req = parser.add_argument_group('Required arguments')
    group_req.add_argument('command', help="Command to be executed. Use -- after arguments accepting multiple values.", nargs='?')

    # Optional
    group_opt = parser.add_argument_group('Optional arguments')
    group_opt.add_argument('--system-ssh', dest='system_ssh', action='store_true', help="Call system SSH client instead of Pythonish (worse escaping, obsolete)")
    group_opt.add_argument('--serial', '--no-parallel', dest='serial', action='store_true', help="Execute commands on hosts one-by-one")
    # No argparse default here: a default of 120 would be indistinguishable
    # from an explicit -t and would make --serial a no-op.
    group_opt.add_argument('-t', '--threads', dest='threads', type=int, default=None, help="Execute commands on hosts in x threads (default 120)")
    group_opt.add_argument('-f', '--file', dest='file', help="Read list of nodes from file (ignores cluster)")
    group_opt.add_argument('-m', '--machines', dest='machines', nargs='+', default=[], help="List of machines to operate on")
    group_opt.add_argument('-u', '--user', dest='sshUser', help="SSH user to connect with, defaults to the current user", default=getpass.getuser())
    group_opt.add_argument('-K', '--key-file', dest='sshKeyFile', help="SSH user key file to connect with")
    group_opt.add_argument('-d', '--domain', dest='domain', help="Location domain to use")
    group_opt.add_argument('-e', '--exitcode', dest='exitcode', action='store_true', help="Exit with non-zero exit code if command return non-zero exit code")

    # Output switchers
    group_out = parser.add_argument_group('Output switchers')
    group_out.add_argument('--debug', dest='debug', action='store_true')

    # Action switchers
    group_act = parser.add_argument_group('Action switchers')
    group_act.add_argument('-h', '--help', dest='help', action='store_true', help="Show this help")
    group_act.add_argument('-L', '--list-nodes', dest='list_nodes', action='store_true', help="List nodes where command would be executed")
    group_act.add_argument('-I', '--interactive', '--shell', dest='interactive', action='store_true', help="Run in interactive mode, same as if command is -")
    group_act.add_argument('-U', '--upload', dest='upload', help="Upload file to [command] on nodes")

    args = parser.parse_args()

    if args.debug:
        lg.setLevel(logging.DEBUG)

    if args.help:
        parser.print_help()
        sys.exit(0)

    if args.interactive:
        args.command = '-'

    if args.threads is not None and args.threads == 0:
        # Easter egg for "-t 0"; decode() keeps the output readable where
        # b64decode returns bytes.
        print(base64.b64decode('ICAgICAgIF8gICAgIF8KICAgICAgIFxgXCAvYC8KICAgICAgICBcIFYgLyAgICAgICAgICAgICAgIAogICAgICAgIC8uIC5cICAgICAgIAogICAgICAgPVwgVCAvPSAgICAgICAgICAgICAgICAgIAogICAgICAgIC8gXiBcICAgICAKICAgICAgIC9cXCAvL1wKICAgICBfX1wgIiAiIC9fXyAgICAgICAgICAgCiAgICAoX19fXy9eXF9fX18pCiAgWW91J3JlIGEgVGVhcG90IQo=').decode('ascii'))
        sys.exit(1)

    if args.threads is not None and args.threads < 0:
        lg.error("Number of threads has to be positive")
        sys.exit(1)

    # Set thread limiter: explicit -t wins, then --serial, then the default
    if args.threads:
        threadLimiter = threading.BoundedSemaphore(args.threads)
    elif args.serial:
        threadLimiter = threading.BoundedSemaphore(1)
    else:
        # Default limit is 120 threads at once
        threadLimiter = threading.BoundedSemaphore(120)

    # Can't read from stdin for multiple options
    if args.file == '-' and args.command == '-':
        lg.error("Can't read nodes and command from stdin, try to use -m option instead of -f")
        sys.exit(1)

    lg.debug("Command: %s" % args.command)

    if args.file:
        if args.file != '-':
            try:
                with open(args.file, 'r') as handle:
                    lines = handle.readlines()
            except (IOError, OSError):
                lg.error("Can't open file %s" % args.file)
                sys.exit(1)
        else:
            lines = sys.stdin.readlines()

        # Strip line endings / padding and skip blank lines so they do not
        # turn into bogus host names.
        machines = [line.strip() for line in lines if line.strip()]
    elif args.machines:
        machines = args.machines
    else:
        raise RuntimeError("You need to submit list of hosts to connect to")

    sshHosts = machines

    # Interactive mode
    if args.command == '-':
        import readline
        readline.parse_and_bind('tab: complete')
        readline.parse_and_bind('set editing-mode vi')

        # raw_input only exists on Python 2; fall back to input elsewhere
        try:
            read_command = raw_input
        except NameError:
            read_command = input

        while True:
            try:
                args.command = read_command("$> ")
            except (KeyboardInterrupt, SystemExit, EOFError):
                lg.debug("Interrupted")
                sshCleanup()
                print('')
                sys.exit(0)

            if args.command in ['exit', 'quit']:
                sshCleanup()
                sys.exit(0)

            if args.command:
                # Do the job and wait till all threads are done
                pool = run(machines, args)
                if not _wait_for_pool(pool):
                    sshCleanup()
                    sys.exit(1)

    # Do the job (normal mode)
    pool = run(machines, args)

    # Wait till all threads are done
    try:
        _wait_for_pool(pool)
    finally:
        sshCleanup()

    if sshFailed:
        lg.error("Failed connections (%s/%s): %s" % (len(sshFailed), len(sshHosts), ','.join(sshFailed)))
        sys.exit(1)

    if args.exitcode:
        sys.exit(exitcode)

def wait_threads():
    """
    Block until the calling thread is the only active one left.

    Usually not what we want: transport threads that never end are
    counted as well, so this may wait forever. Pooled SSH connections
    are closed on the way out in every case.
    """
    try:
        remaining = threading.activeCount() - 1
        while remaining > 0:
            lg.debug("Waiting for %i threads" % remaining)
            time.sleep(0.5)
            remaining = threading.activeCount() - 1
    except (KeyboardInterrupt, SystemExit, EOFError):
        # Interrupted: try to stop every thread that is still registered
        for worker in threading.enumerate():
            worker_name = worker.getName()
            lg.debug("Killing thread %s" % worker_name)
            try:
                worker._Thread__stop()
            except Exception as err:
                lg.error("Thread %s cannot be terminated: %s" % (worker_name, err))
    finally:
        sshCleanup()

def run(machines, args):
    """
        start one worker thread per host and return the started threads

        machines -- list of host names, or dict mapping host name to a node
                    dict; a list is expanded to node dicts with the keys
                    'hostname', 'ip_public', 'ip' and 'instance_id'
        args     -- argparse namespace built by main()

        Every node dict gets a 'connect' key (host name, suffixed with
        args.domain when set). With args.list_nodes the hosts are only
        printed and the returned list is empty. Otherwise a thread running
        uploadFile, runSSH or runRemote is started for each host.

        Exits with status 1 when a command is required but missing.
    """
    lg.debug("Hosts: %s" % machines)

    pool = []

    if isinstance(machines, list):
        machines = dict(
            (host, {
                'hostname'    : host,
                'ip_public'   : host,
                'ip'          : host,
                'instance_id' : None,
            })
            for host in machines
        )

    # The parser in main() defines no 'ip' option, so do not assume the
    # attribute exists on the namespace.
    show_address = getattr(args, 'ip', False)
    warned_system_ssh = False

    for hostname, node in machines.items():
        if args.domain:
            node['connect'] = "%s.%s" % (hostname, args.domain)
        else:
            node['connect'] = hostname

        if args.list_nodes:
            if not show_address:
                print("{0:<25}{1}".format(hostname, node['instance_id']))
            else:
                print("{0:<25}{1}{2:>20}".format(hostname, node['instance_id'], node['connect']))
            continue

        if not args.command:
            lg.error("'command' option have to be set")
            sys.exit(1)

        # All workers are named after their host so the signal report can
        # list which hosts are still busy.
        if args.upload:
            t = threading.Thread(name=hostname, target=uploadFile, args=(node, args.upload, args.command, args.sshUser, args.sshKeyFile))
        elif args.system_ssh:
            if not warned_system_ssh:
                # Once per run is enough, not once per host
                lg.warning('Using system SSH is obsolete and may be buggy. Avoid using this option!')
                warned_system_ssh = True
            t = threading.Thread(name=hostname, target=runSSH, args=(node['connect'], args.command, args.sshUser, args.sshKeyFile))
        else:
            t = threading.Thread(name=hostname, target=runRemote, args=(node, args.command, args.sshUser, args.sshKeyFile))

        t.start()
        pool.append(t)
    return pool

def runSSH(name, command, user, keyFile=None):
    """
        execute command on <name> host with the system ssh client

        name    -- address to connect to
        command -- shell command line to run on the remote host
        user    -- remote login name
        keyFile -- optional identity file; omitted from the ssh call when
                   not given

        Each line of output is written to stdout as "<name>: <line>". A
        non-zero exit status of ssh is stored in the global exitcode.
    """
    global exitcode

    import subprocess

    threadLimiter.acquire()
    try:
        lg.debug("Execute '%s' on '%s'" % (command, name))
        remote_command = 'source /etc/profile >/dev/null;%s 2>&1' % command

        # Argument list instead of a shell string: the command reaches ssh
        # as one argument and is not re-interpreted by a local shell.
        argv = ['ssh', '-qAY',
                '-o', 'UserKnownHostsFile=/dev/null',
                '-o', 'StrictHostKeyChecking=no']
        if keyFile:
            argv += ['-o', 'IdentityFile=%s' % keyFile]
        argv += ['%s@%s' % (user, name), '--', remote_command]

        try:
            proc = subprocess.Popen(argv, stdout=subprocess.PIPE, universal_newlines=True)
        except OSError as e:
            lg.error("Can't execute ssh for %s: %s" % (name, e))
            exitcode = 1
            return

        try:
            # readline() hands over each line as soon as it is complete
            for line in iter(proc.stdout.readline, ''):
                sys.stdout.write("%s: %s" % (name, line))
                sys.stdout.flush()
        finally:
            proc.stdout.close()
            status = proc.wait()

        lg.debug("Exit status: %s" % status)
        if status != 0:
            exitcode = status
    finally:
        threadLimiter.release()

def sshCleanup():
    """
        close every connection stored in sshPool
    """
    for address, client in sshPool.items():
        lg.debug("Closing connection to %s" % address)
        client.close()

def runRemote(node, command, user, keyFile=None):
    """
        execute command on <node> host with Paramiko

        node    -- node dict; 'connect' is the address to connect to,
                   'hostname' the label used in output and log messages
        command -- shell command line to run on the remote host
        user    -- remote login name
        keyFile -- optional private key file

        Established connections are kept in sshPool and reused by later
        calls; addresses that failed are remembered in sshFailed and
        skipped. Before the command itself a 'hostname' probe with a 5
        second timeout checks that the connection works.

        Returns False when the host could not be used, True when the
        command is exactly 'hostname' (the probe output is the answer) and
        None otherwise. A non-zero remote exit status is stored in the
        global exitcode. node['hostname'] is appended to sshDone in every
        case.
    """
    global sshPool
    global sshDone
    global exitcode

    threadLimiter.acquire()
    try:
        lg.debug("Execute '%s' on '%s'" % (command, node))
        # Keep the caller's command intact: it is compared with 'hostname'
        # further down.
        remote_command = 'source /etc/profile >/dev/null;%s' % command

        connect = node['connect']

        if connect in sshFailed:
            return False

        client = sshPool.get(connect)
        if client is None:
            client = paramiko.SSHClient()
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            try:
                client.load_system_host_keys()
            except paramiko.SSHException as e:
                lg.error("Can't load system known hosts: %s" % e)
                sshFailed.append(connect)
                return False

            try:
                client.connect(connect, username=user, timeout=5, key_filename=keyFile)
            except KeyboardInterrupt:
                lg.info("Interrupted")
                sys.exit(0)
            except socket.timeout:
                # Has to come before socket.error, which also matches timeouts
                lg.error("Timeout during connecting to %s (%s)" % (connect, node['hostname']))
                client.close()
                sshFailed.append(connect)
                return False
            except (socket.gaierror, socket.error) as e:
                lg.error("Can't connect to %s (%s): %s" % (connect, node['hostname'], e))
                client.close()
                sshFailed.append(connect)
                return False
            except paramiko.SSHException as e:
                lg.error("Can't connect to %s (%s) as user %s: %s" % (connect, node['hostname'], user, e))
                client.close()
                sshFailed.append(connect)
                return False

            # Only working connections go into the pool
            sshPool[connect] = client

        trans = client.get_transport()
        if not trans:
            lg.error("Can't get transport for connection %s (%s)" % (connect, node['hostname']))
            sshFailed.append(connect)
            return False

        try:
            chan = trans.open_session()
        except paramiko.SSHException:
            # open_session() raises when the session cannot be opened,
            # e.g. because a pooled connection has gone away
            lg.error("Connection to %s (%s) no longer active" % (connect, node['hostname']))
            sshFailed.append(connect)
            return False
        chan.get_pty()

        # Timeout 5 seconds for first command
        # to test connection
        chan.settimeout(5)

        try:
            output = chan.makefile()
            chan.exec_command('hostname')
            for line in output:
                if command != 'hostname':
                    lg.debug("Connected to %s (hostname %s)" % (connect, line.replace("\r\n", "\n")))
                else:
                    # The probe already produced what was asked for
                    sys.stdout.write("%s: %s" % (connect, line.replace("\r\n", "\n")))
                    sys.stdout.flush()
                    return True
        except socket.timeout:
            lg.error("Timeout during communication with %s (%s)" % (connect, node['hostname']))
            return False
        finally:
            chan.close()

        # Channel without timeout for our command
        try:
            chan = trans.open_session()
        except paramiko.SSHException:
            lg.error("Connection to %s (%s) no longer active" % (connect, node['hostname']))
            sshFailed.append(connect)
            return False
        chan.settimeout(None)
        chan.get_pty()
        output = chan.makefile()
        chan.exec_command(remote_command)

        for line in output:
            sys.stdout.write("%s: %s" % (node['hostname'], line.replace("\r\n", "\n")))
            sys.stdout.flush()

        # Read the exit status before closing: once the channel is closed
        # a status that has not arrived yet is reported as -1.
        status = chan.recv_exit_status()
        chan.close()
        lg.debug("Exit status: %s" % status)
        if status != -1 and status != 0:
            exitcode = status
    finally:
        sshDone.append(node['hostname'])
        threadLimiter.release()

def uploadFile(node, localFile, remoteFile, user, keyFile=None):
    """
        upload file to remote host with Paramiko

        node       -- node dict; 'connect' is the address to connect to,
                      'hostname' the label used in log messages
        localFile  -- path of the file to send
        remoteFile -- destination path on the remote host
        user       -- remote login name
        keyFile    -- optional private key file

        Returns False when the connection or the transfer failed, None on
        success. A failed connection is recorded in sshFailed, a failed
        transfer sets the global exitcode to 1. The connection is always
        closed and node['hostname'] is appended to sshDone.
    """
    global exitcode

    connect = node['connect']

    lg.debug("Upload '%s' on '%s:%s'" % (localFile, node['hostname'], remoteFile))

    # Respect the thread limit like the other workers; the limiter only
    # exists once main() has created it.
    limiter = threadLimiter
    if limiter is not None:
        limiter.acquire()

    ssh = paramiko.SSHClient()
    try:
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        try:
            ssh.load_system_host_keys()
            ssh.connect(connect, username=user, timeout=5, key_filename=keyFile)
        except Exception as e:
            # Worker boundary: report whatever went wrong instead of dying
            # with a traceback
            lg.warning("Can't connect to %s (%s): %s" % (connect, node['hostname'], e))
            sshFailed.append(connect)
            return False

        try:
            ftp = ssh.open_sftp()
            try:
                ftp.put(localFile, remoteFile)
            finally:
                ftp.close()
        except (OSError, IOError, paramiko.SSHException) as e:
            # sys.exit() would only end this thread without telling anyone,
            # so record the failure instead
            lg.error(e)
            exitcode = 1
            return False
    finally:
        ssh.close()
        sshDone.append(node['hostname'])
        if limiter is not None:
            limiter.release()

def siginfo_handler(signum, frame):
    """
        print a progress report; main() installs this for SIGINFO and
        SIGUSR1. Nothing is printed when only one thread is active.
    """
    if threading.activeCount() <= 1:
        return

    nodes_active = [
        worker.getName()
        for worker in threading.enumerate()
        if worker.getName() != 'MainThread'
    ]

    report = (
        "--",
        "Done: %s/%s" % (len(sshDone), len(sshHosts)),
        "SSH connections: %s" % len(sshPool),
        "Connections failed: %s" % len(sshFailed),
        "Threads count: %s" % threading.activeCount(),
        "Thread names: %s" % ','.join(nodes_active),
        "--",
    )
    for entry in report:
        print(entry)

# Run the tool only when the file is executed as a script, not on import.
if __name__ == '__main__':
    main()