#!/usr/bin/python3

import argparse
import sys
import io
from azuremetadata import azuremetadatautils, azuremetadata


class PreserveArgumentOrder(argparse.Action):
    """argparse action that records options in the order they were given.

    Every time an option using this action is seen on the command line, the
    pair ``(dest, value)`` is appended to ``namespace.ordered_args``. That list
    is created the first time the action fires. This action never assigns
    ``namespace.<dest>`` itself, so that attribute keeps whatever default
    argparse gave it.
    """

    def __call__(self, parser, namespace, value, option_string=None):
        if 'ordered_args' in namespace:
            history = namespace.ordered_args
        else:
            # First recorded option: start a fresh list on this namespace.
            history = []
            setattr(namespace, 'ordered_args', history)
        history.append((self.dest, value))


# A preliminary parser that picks out only the options needed before the
# metadata service is contacted: the API version to request and the device to
# read the disk tag from. Every other command-line argument is ignored here
# (parse_known_args) and left for the full parser.
api_version_parser = argparse.ArgumentParser(add_help=False)
for _flags in (('-a', '--api'), ('--device',)):
    # nargs='?' with const=None: the flag may appear without a value.
    api_version_parser.add_argument(*_flags, nargs='?', const=None)
api_args = api_version_parser.parse_known_args()[0]

# The full command-line parser. Only the fixed options are registered here.
# One query option per metadata key is added later, once the metadata is known.
# help_header is rendered now, so it describes just these fixed options.
parser = argparse.ArgumentParser(add_help=False)
for _flags, _options in (
    (('-h', '--help'), {'action': "store_true", 'help': "Display help"}),
    (('-x', '--xml'), {'action': "store_true", 'help': "Output as XML"}),
    (('-j', '--json'), {'action': "store_true", 'help': "Output as JSON"}),
    (('-o', '--output'), {'help': "Output file path"}),
    (('-a', '--api'), {'help': "API version"}),
    (('--device',), {
        'help': "Device to read disk tag from (default: root device)",
        'nargs': '?',
    }),
):
    parser.add_argument(*_flags, **_options)

# print_help() writes to any file-like object; capture its output as a string.
_help_buffer = io.StringIO()
with _help_buffer:
    parser.print_help(_help_buffer)
    help_header = _help_buffer.getvalue()

def _write_output(util, args, results, fh):
    """Pretty-print either the whole metadata document or a set of query results.

    util: the AzureMetadataUtils instance built from the fetched metadata.
    args: the parsed command line; only ``args.xml`` and ``args.json`` are read.
    results: ``None`` to print the whole document, otherwise an iterable of
        query results, each printed with its own ``print_pretty`` call.
    fh: open file object to write to, or ``None`` to let ``print_pretty`` use
        its default destination (presumably stdout -- not visible here).
    """
    if results is None:
        util.print_pretty(print_xml=args.xml, print_json=args.json, file=fh)
        return
    for item in results:
        util.print_pretty(print_xml=args.xml, print_json=args.json, data=item, file=fh)


try:
    metadata = azuremetadata.AzureMetadata(api_args.api)

    data = metadata.get_all()
    data['billingTag'] = metadata.get_disk_tag(api_args.device)

    util = azuremetadatautils.AzureMetadataUtils(data)
    # One optional query flag per available key. A bare "--key" yields True
    # and "--key N" yields int N. PreserveArgumentOrder records the flags in
    # command-line order for util.query().
    for key in util.available_params.keys():
        parser.add_argument('--{}'.format(key), nargs='?', type=int, const=True,
                            action=PreserveArgumentOrder)

    args = parser.parse_args()
    ordered_args = getattr(args, 'ordered_args', [])

    if args.help:
        print(help_header)
        print("\nquery arguments:")
        util.print_help()
        sys.exit()

    # Run the query before touching the output file. If util.query raises
    # eagerly, an existing output file is then left intact instead of being
    # truncated. A lazily evaluated query can still raise later, while printing.
    results = util.query(ordered_args) if ordered_args else None

    if args.output:
        try:
            fh = open(args.output, 'w')
        except OSError as e:
            # Report an unwritable output path cleanly instead of a traceback.
            print(e, file=sys.stderr)
            sys.exit(1)
        with fh:
            _write_output(util, args, results, fh)
    else:
        _write_output(util, args, results, None)

except azuremetadatautils.QueryException as e:
    print(e, file=sys.stderr)
    sys.exit(1)
