#!/usr/bin/env python3
# Copyright (C) 2023 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import filecmp
import os
import pathlib
import shutil
import subprocess
import sys
import tempfile

# Settings applied to the production protos (PROD_FILES) and to any file passed
# via --proto.
#
# The guard_* and path_* entries are forwarded unchanged to the ProtoZero C
# plugin as plugin options (see generate()). The path_* entries are also used
# by main() to work out where the generated header is copied to: the strip
# prefix is removed from the source path and the add prefix is prepended.
# 'include_prefix' is only read by main(), which joins it in front of the
# destination path.
# NOTE(review): the guard_* options presumably rewrite the header include
# guards; that happens inside the plugin, confirm there.
PERFETTO_DEFAULTS = {
    'guard_strip_prefix': 'PROTOS_PERFETTO_',
    'guard_add_prefix': 'INCLUDE_PERFETTO_PUBLIC_PROTOS_',
    'path_strip_prefix': 'protos/perfetto',
    'path_add_prefix': 'perfetto/public/protos',
    'include_prefix': 'include/',
}

# Settings applied to the test protos (TEST_FILES). There is no
# 'include_prefix' entry, so main() does not prepend anything to the
# destination path for these files.
TEST_DEFAULTS = {
    'guard_strip_prefix': 'SRC_PROTOZERO_TEST_EXAMPLE_PROTO_',
    'guard_add_prefix': 'SRC_SHARED_LIB_TEST_PROTOS_',
    'path_strip_prefix': 'src/protozero/test/example_proto',
    'path_add_prefix': 'src/shared_lib/test/protos',
}

# Production .proto files processed with PERFETTO_DEFAULTS when the script is
# run without --proto. Paths are passed to protoc as-is; protoc runs with the
# repository root as working directory and --proto_path=. (see generate()).
PROD_FILES = [
    'protos/perfetto/common/builtin_clock.proto',
    'protos/perfetto/common/data_source_descriptor.proto',
    'protos/perfetto/common/semantic_type.proto',
    'protos/perfetto/config/data_source_config.proto',
    'protos/perfetto/config/trace_config.proto',
    'protos/perfetto/config/track_event/track_event_config.proto',
    'protos/perfetto/trace/android/android_track_event.proto',
    'protos/perfetto/trace/clock_snapshot.proto',
    'protos/perfetto/trace/interned_data/interned_data.proto',
    'protos/perfetto/trace/test_event.proto',
    'protos/perfetto/trace/trace.proto',
    'protos/perfetto/trace/trace_packet.proto',
    'protos/perfetto/trace/track_event/counter_descriptor.proto',
    'protos/perfetto/trace/track_event/debug_annotation.proto',
    'protos/perfetto/trace/track_event/track_descriptor.proto',
    'protos/perfetto/trace/track_event/track_event.proto',
]

# Test .proto files processed with TEST_DEFAULTS when the script is run
# without --proto.
TEST_FILES = [
    'src/protozero/test/example_proto/extensions.proto',
    'src/protozero/test/example_proto/library.proto',
    'src/protozero/test/example_proto/library_internals/galaxies.proto',
    'src/protozero/test/example_proto/other_package/test_messages.proto',
    'src/protozero/test/example_proto/subpackage/test_messages.proto',
    'src/protozero/test/example_proto/test_messages.proto',
    'src/protozero/test/example_proto/upper_import.proto',
]

# Parent of the directory that contains this script, with symlinks resolved.
# Used as the working directory for every subprocess started by this script.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
# True on Windows; selects the '.exe' suffix for the tool binaries.
IS_WIN = sys.platform.startswith('win')

# Passed to the ProtoZero C plugin as the 'invoker' option (see generate()).
SCRIPT_PATH = 'tools/gen_c_protos'


# Returns the path of the protozero_c_plugin binary inside out_directory
# ('.exe' is appended on Windows).
#
# Raises AssertionError if no such file exists.
def protozero_c_plugin_path(out_directory):
  path = os.path.join(out_directory,
                      'protozero_c_plugin') + ('.exe' if IS_WIN else '')
  if not os.path.isfile(path):
    # Raised explicitly instead of via `assert` so that the check is not
    # compiled out under `python -O` and the failure carries a message.
    raise AssertionError('protozero_c_plugin not found at {}'.format(path))
  return path


# Returns the path of the protoc binary inside out_directory ('.exe' is
# appended on Windows).
#
# Raises AssertionError if no such file exists.
def protoc_path(out_directory):
  path = os.path.join(out_directory, 'protoc') + ('.exe' if IS_WIN else '')
  if not os.path.isfile(path):
    # Raised explicitly instead of via `assert` so that the check is not
    # compiled out under `python -O` and the failure carries a message.
    raise AssertionError('protoc not found at {}'.format(path))
  return path


# Runs tools/<cmd> through python3 with the given arguments, using the
# repository root as working directory. The command line is printed first.
#
# Raises AssertionError (chained to the CalledProcessError) if the command
# exits with a non-zero status.
def call(cmd, *args):
  path = os.path.join('tools', cmd)
  command = ['python3', path] + list(args)
  print('Running', ' '.join(command))
  try:
    subprocess.check_call(command, cwd=ROOT_DIR)
  except subprocess.CalledProcessError as e:
    # Explicit raise: `assert False` is removed under `python -O`, which would
    # turn a failed command into a silent success.
    raise AssertionError('Command: {} failed'.format(
        ' '.join(command))) from e


# Reformats filename in place (clang-format -i) with the clang-format binary
# under third_party/clang-format, using the repository's .clang-format file
# as the style.
#
# Raises AssertionError if the clang-format binary is missing.
def clang_format(filename):
  path = os.path.join(ROOT_DIR, 'third_party', 'clang-format',
                      'clang-format') + ('.exe' if IS_WIN else '')
  if not os.path.isfile(path):
    # Raised explicitly instead of via `assert` so that the check is not
    # compiled out under `python -O`.
    raise AssertionError(
        'clang-format not found. Run tools/install-build-deps')
  style_arg = '--style=file:{}'.format(os.path.join(ROOT_DIR, '.clang-format'))
  subprocess.check_call([path, style_arg, '-i', filename], cwd=ROOT_DIR)


# Maps a .proto filename to the header name emitted by the ProtoZero C plugin:
# a trailing ".proto" becomes ".pzc.h". Any other filename is returned as-is.
def transform_extension(filename):
  proto_ext, header_ext = ".proto", ".pzc.h"
  if not filename.endswith(proto_ext):
    return filename
  stem = filename[:len(filename) - len(proto_ext)]
  return stem + header_ext


# Runs protoc with the ProtoZero C plugin on a single proto file.
#
# source is handed to protoc unchanged; protoc runs with the repository root
# as working directory and --proto_path=. The four prefix arguments plus the
# invoker name are serialized as comma-separated name=value plugin options,
# and the plugin writes its output below outdir.
#
# Raises subprocess.CalledProcessError if protoc exits with a non-zero status.
def generate(source, outdir, protoc_path, protozero_c_plugin_path,
             guard_strip_prefix, guard_add_prefix, path_strip_prefix,
             path_add_prefix):
  plugin_options = (
      ('guard_strip_prefix', guard_strip_prefix),
      ('guard_add_prefix', guard_add_prefix),
      ('path_strip_prefix', path_strip_prefix),
      ('path_add_prefix', path_add_prefix),
      ('invoker', SCRIPT_PATH),
  )
  option_string = ','.join(
      '{}={}'.format(key, val) for key, val in plugin_options)
  plugin_arg = '--plugin=protoc-gen-plugin={}'.format(protozero_c_plugin_path)
  plugin_out_arg = '--plugin_out={}:{}'.format(option_string, outdir)
  protoc_cmd = [protoc_path, '--proto_path=.', plugin_arg, plugin_out_arg, source]
  subprocess.check_call(protoc_cmd, cwd=ROOT_DIR)


# Maps filename (a proto, or a header generated from it by the ProtoZero C
# plugin) to the location of the header in the public include directory. The
# first two path components of filename are dropped and the rest is appended
# to include/perfetto/public/protos.
# Example
#
# include_path_for("protos/perfetto/trace/trace.pzc.h") ==
# "include/perfetto/public/protos/trace/trace.pzc.h"
def include_path_for(filename):
  header = pathlib.Path(transform_extension(filename))
  tail = header.parts[2:]
  return os.path.join('include', 'perfetto', 'public', 'protos', *tail)


# Returns the repository-relative destination of the header generated from
# source: the group's path_strip_prefix is removed (when present), its
# path_add_prefix is prepended, the optional include_prefix is joined in front
# and the extension is rewritten to .pzc.h.
def _target_filename(source, sources):
  strip_prefix = sources['path_strip_prefix']
  if source.startswith(strip_prefix):
    relative = source[len(strip_prefix):]
  else:
    relative = source
  target = sources['path_add_prefix'] + relative
  if 'include_prefix' in sources:
    target = os.path.join(sources['include_prefix'], target)
  return transform_extension(target)


# Script entry point.
#
# Builds protoc and the ProtoZero C plugin with ninja, generates and
# clang-formats a header for every selected proto in a temporary directory,
# then either copies each header to its destination or, with --check-only,
# verifies that the destination already has identical contents.
#
# Returns 0 on success and 1 after printing the message of an AssertionError
# that carries one. An AssertionError without a message is re-raised.
def main():
  parser = argparse.ArgumentParser(
      formatter_class=argparse.RawDescriptionHelpFormatter,
      description="""Generates C API header files (.pzc.h) from proto definitions.

This script generates ProtoZero C headers for both standard Perfetto protos
and custom proto files.

OUTPUT:
  Generated .pzc.h files are placed in include/perfetto/public/protos/
  mirroring the structure from protos/perfetto/:
    protos/perfetto/trace/trace.proto
      -> include/perfetto/public/protos/trace/trace.pzc.h""",
      epilog="""Example workflow:
  # Generate custom protos
  tools/gen_c_protos out/linux_clang_release \\
      --proto protos/perfetto/trace/android/your_custom.proto \\
      --proto protos/perfetto/common/your_common.proto
""")

  parser.add_argument(
      'OUT', help='Build output directory (e.g., out/linux_clang_release)')
  parser.add_argument(
      '--proto',
      action='append',
      help='Additional .proto files to generate. Path must be relative to '
      'the repository root '
      '(e.g., protos/perfetto/common/builtin_clock.proto)',
      default=[])
  parser.add_argument(
      '--check-only',
      action='store_true',
      help='Verify generated files match existing without writing')

  args = parser.parse_args()

  out = args.OUT

  if args.proto:
    # Explicitly requested protos replace the built-in lists and are treated
    # like production protos.
    source_files = [{
        'files': args.proto,
        **PERFETTO_DEFAULTS,
    }]
  else:
    source_files = [{
        'files': PROD_FILES,
        **PERFETTO_DEFAULTS,
    }, {
        'files': TEST_FILES,
        **TEST_DEFAULTS,
    }]

  try:
    # Inside the try block so that a failed build is reported through the same
    # 'Error: ...' path as every other failure.
    call('ninja', '-C', out, 'protoc', 'protozero_c_plugin')

    # The tool locations do not depend on the source file: resolve them once.
    protoc = protoc_path(out)
    plugin = protozero_c_plugin_path(out)

    with tempfile.TemporaryDirectory() as tmpdirname:
      for sources in source_files:
        for source in sources['files']:
          generate(
              source,
              tmpdirname,
              protoc,
              plugin,
              guard_strip_prefix=sources['guard_strip_prefix'],
              guard_add_prefix=sources['guard_add_prefix'],
              path_strip_prefix=sources['path_strip_prefix'],
              path_add_prefix=sources['path_add_prefix'],
          )

          tmpfilename = os.path.join(tmpdirname, transform_extension(source))
          clang_format(tmpfilename)

          targetfilename = _target_filename(source, sources)
          # Sources are resolved against ROOT_DIR (protoc and clang-format run
          # with cwd=ROOT_DIR), so anchor the destination there as well
          # instead of depending on the caller's working directory.
          targetpath = os.path.join(ROOT_DIR, targetfilename)

          if args.check_only:
            # A missing target counts as a mismatch. shallow=False forces a
            # content comparison rather than trusting os.stat() signatures.
            if not (os.path.isfile(targetpath) and
                    filecmp.cmp(tmpfilename, targetpath, shallow=False)):
              raise AssertionError(
                  'Target {} does not match'.format(targetfilename))
            print('Verified {}'.format(targetfilename))
          else:
            os.makedirs(os.path.dirname(targetpath), exist_ok=True)
            shutil.copyfile(tmpfilename, targetpath)
            print('Generated {}'.format(targetfilename))

  except AssertionError as e:
    # Message-less assertions are internal errors: keep the traceback.
    if not str(e):
      raise
    print('Error: {}'.format(e))
    return 1

  return 0


if __name__ == '__main__':
  # sys.exit instead of the `exit` builtin: `exit` is installed by the `site`
  # module for interactive use and is not available under `python -S`.
  sys.exit(main())
