// Adapted from
// https://github.com/leandromoreira/ffmpeg-libav-tutorial/blob/master/2_remuxing.c
// based on https://ffmpeg.org/doxygen/trunk/remuxing_8c-example.html
// FIXME: looks ugly, further refactor needed :(
#include <libavformat/avformat.h>
#include <libavutil/timestamp.h>
#include <stdio.h>

#include "../logger.h"
#include "ffmpeg.h"

static status_t *stat_g;

/* Dump a packet's timing info (pts/dts/duration, both raw and converted to
 * seconds via the stream's time base) to stdout.  Compiles to a no-op when
 * DEBUG is not defined. */
static void log_packet(const AVFormatContext *fmt_ctx, const AVPacket *pkt,
                       const char *tag) {
#ifdef DEBUG
  AVRational *time_base = &fmt_ctx->streams[pkt->stream_index]->time_base;

  printf("%s: pts:%s pts_time:%s dts:%s dts_time:%s duration:%s "
         "duration_time:%s stream_index:%d\n",
         tag, av_ts2str(pkt->pts), av_ts2timestr(pkt->pts, time_base),
         av_ts2str(pkt->dts), av_ts2timestr(pkt->dts, time_base),
         av_ts2str(pkt->duration), av_ts2timestr(pkt->duration, time_base),
         pkt->stream_index);
#else
  /* Silence unused-parameter warnings in non-DEBUG builds. */
  (void)fmt_ctx;
  (void)pkt;
  (void)tag;
#endif
}

void setup_stat(status_t *stat) { stat_g = stat; }

int merge_av(const char *videofn, const char *audiofn, const char *outfn) {
  AVFormatContext *input1_format_context = NULL, *input2_format_context = NULL,
                  *output_format_context = NULL;
  AVPacket packet;
  int ret, i;
  int stream_index = 0;
  int *streams_list = NULL;
  int number_of_streams = 0;

  AVDictionary *opts = NULL;
  /* Support fragmented MP4
   * https://developer.mozilla.org/en-US/docs/Web/API/Media_Source_Extensions_API/Transcoding_assets_for_MSE
   */
  av_dict_set(&opts, "movflags", "frag_keyframe+empty_moov+default_base_moof",
              0);
  /* Enable all protocols support */
  av_dict_set(&opts, "protocol_whitelist",
              "file,crypto,data,http,https,tcp,tls", 0);

  if ((ret = avformat_open_input(&input1_format_context, videofn, NULL, NULL)) <
      0) {
    append_log("Could not open input file '%s'\n", videofn);
    goto end;
  }
  if ((ret = avformat_open_input(&input2_format_context, audiofn, NULL, NULL)) <
      0) {
    append_log("Could not open input file '%s'\n", audiofn);
    goto end;
  }
  if ((ret = avformat_find_stream_info(input1_format_context, NULL)) < 0) {
    append_log("Failed to retrieve input stream information\n");
    goto end;
  }
  if ((ret = avformat_find_stream_info(input2_format_context, NULL)) < 0) {
    append_log("Failed to retrieve input stream information\n");
    goto end;
  }

  stat_g->total = input1_format_context->duration / AV_TIME_BASE;

  avformat_alloc_output_context2(&output_format_context, NULL, NULL, outfn);
  if (!output_format_context) {
    append_log("Could not create output context\n");
    ret = AVERROR_UNKNOWN;
    goto end;
  }

  number_of_streams =
      input1_format_context->nb_streams + input2_format_context->nb_streams;
  streams_list = av_malloc_array(number_of_streams, sizeof(*streams_list));

  if (!streams_list) {
    ret = AVERROR(ENOMEM);
    goto end;
  }

  for (i = 0; i < input1_format_context->nb_streams; i++) {
    AVStream *out_stream;
    AVStream *in_stream = input1_format_context->streams[i];
    AVCodecParameters *in_codecpar = in_stream->codecpar;
    if (in_codecpar->codec_type != AVMEDIA_TYPE_AUDIO &&
        in_codecpar->codec_type != AVMEDIA_TYPE_VIDEO &&
        in_codecpar->codec_type != AVMEDIA_TYPE_SUBTITLE) {
      streams_list[i] = -1;
      continue;
    }
    streams_list[i] = stream_index++;
    out_stream = avformat_new_stream(output_format_context, NULL);
    if (!out_stream) {
      append_log("Failed allocating output stream\n");
      ret = AVERROR_UNKNOWN;
      goto end;
    }
    ret = avcodec_parameters_copy(out_stream->codecpar, in_codecpar);
    if (ret < 0) {
      append_log("Failed to copy codec parameters\n");
      goto end;
    }
  }
  for (i = 0; i < input2_format_context->nb_streams; i++) {
    AVStream *out_stream;
    AVStream *in_stream = input2_format_context->streams[i];
    AVCodecParameters *in_codecpar = in_stream->codecpar;
    if (in_codecpar->codec_type != AVMEDIA_TYPE_AUDIO &&
        in_codecpar->codec_type != AVMEDIA_TYPE_VIDEO &&
        in_codecpar->codec_type != AVMEDIA_TYPE_SUBTITLE) {
      streams_list[i + stream_index] = -1;
      continue;
    }
    streams_list[i + stream_index] = stream_index;
    stream_index++;
    out_stream = avformat_new_stream(output_format_context, NULL);
    if (!out_stream) {
      append_log("Failed allocating output stream\n");
      ret = AVERROR_UNKNOWN;
      goto end;
    }
    ret = avcodec_parameters_copy(out_stream->codecpar, in_codecpar);
    if (ret < 0) {
      append_log("Failed to copy codec parameters\n");
      goto end;
    }
  }
#ifdef DEBUG
  av_dump_format(input1_format_context, 0, videofn, 0);
  av_dump_format(input2_format_context, 0, audiofn, 0);
  av_dump_format(output_format_context, 0, outfn, 1);
#endif
  if (!(output_format_context->oformat->flags & AVFMT_NOFILE)) {
    ret = avio_open(&output_format_context->pb, outfn, AVIO_FLAG_WRITE);
    if (ret < 0) {
      append_log("Could not open output file '%s'", outfn);
      goto end;
    }
  }

  ret = avformat_write_header(output_format_context, &opts);
  if (ret < 0) {
    append_log("Error occurred when opening output file\n");
    goto end;
  }
  while (1) {
    AVStream *in_stream, *out_stream;
    ret = av_read_frame(input1_format_context, &packet);
    if (ret < 0)
      break;
    in_stream = input1_format_context->streams[packet.stream_index];
    if (packet.stream_index >= number_of_streams ||
        streams_list[packet.stream_index] < 0) {
      av_packet_unref(&packet);
      continue;
    }
    packet.stream_index = streams_list[packet.stream_index];
    out_stream = output_format_context->streams[packet.stream_index];
    /* copy packet */
    // log_packet(input1_format_context, &packet, "in");
    packet.pts = av_rescale_q_rnd(packet.pts, in_stream->time_base,
                                  out_stream->time_base,
                                  AV_ROUND_NEAR_INF | AV_ROUND_PASS_MINMAX);
    packet.dts = av_rescale_q_rnd(packet.dts, in_stream->time_base,
                                  out_stream->time_base,
                                  AV_ROUND_NEAR_INF | AV_ROUND_PASS_MINMAX);
    packet.duration = av_rescale_q(packet.duration, in_stream->time_base,
                                   out_stream->time_base);
    packet.pos = -1;
    // log_packet(output_format_context, &packet, "out");
    stat_g->cur = (stat_g->cur <= stat_g->total)
                      ? av_q2d(in_stream->time_base) * packet.pts
                      : stat_g->total;

    ret = av_interleaved_write_frame(output_format_context, &packet);
    if (ret < 0) {
      append_log("Error muxing packet\n");
      break;
    }
    av_packet_unref(&packet);
  }
  while (1) {
    AVStream *in_stream, *out_stream;
    ret = av_read_frame(input2_format_context, &packet);
    if (ret < 0)
      break;
    in_stream = input2_format_context->streams[packet.stream_index];
    if (packet.stream_index >= number_of_streams ||
        streams_list[packet.stream_index + input1_format_context->nb_streams] <
            0) {
      av_packet_unref(&packet);
      continue;
    }
    packet.stream_index =
        streams_list[packet.stream_index + input1_format_context->nb_streams];
    out_stream = output_format_context->streams[packet.stream_index];
    /* copy packet */
    // log_packet(input2_format_context, &packet, "in");
    packet.pts = av_rescale_q_rnd(packet.pts, in_stream->time_base,
                                  out_stream->time_base,
                                  AV_ROUND_NEAR_INF | AV_ROUND_PASS_MINMAX);
    packet.dts = av_rescale_q_rnd(packet.dts, in_stream->time_base,
                                  out_stream->time_base,
                                  AV_ROUND_NEAR_INF | AV_ROUND_PASS_MINMAX);
    packet.duration = av_rescale_q(packet.duration, in_stream->time_base,
                                   out_stream->time_base);
    packet.pos = -1;
    // log_packet(output_format_context, &packet, "out");
    stat_g->cur = (stat_g->cur <= stat_g->total)
                      ? av_q2d(in_stream->time_base) * packet.pts
                      : stat_g->total;
    ret = av_interleaved_write_frame(output_format_context, &packet);
    if (ret < 0) {
      append_log("Error muxing packet\n");
      break;
    }
    av_packet_unref(&packet);
  }
  av_write_trailer(output_format_context);
end:
  avformat_close_input(&input1_format_context);
  avformat_close_input(&input2_format_context);
  /* close output */
  if (output_format_context &&
      !(output_format_context->oformat->flags & AVFMT_NOFILE))
    avio_closep(&output_format_context->pb);
  avformat_free_context(output_format_context);
  av_freep(&streams_list);
  if (ret < 0 && ret != AVERROR_EOF) {
    append_log("Error occurred: %s\n", av_err2str(ret));
    return 1;
  }

  // Delete seperate files
  if (remove(videofn) != 0) {
    append_log("Error deleting partial file %s\n", videofn);
  }
  if (remove(audiofn) != 0) {
    append_log("Error deleting partial file %s\n", audiofn);
  }
  return 0;
}

/* Remux in_filename into out_filename (stream copy, no re-encoding), e.g.
 * TS segments into an MP4.  Requires setup_stat() to have been called:
 * progress is reported through the registered status struct.  On success the
 * input file is deleted.  Returns 0 on success, 1 on failure. */
int remux(const char *in_filename, const char *out_filename) {
  const AVOutputFormat *ofmt = NULL;
  AVFormatContext *ifmt_ctx = NULL, *ofmt_ctx = NULL;
  AVPacket *pkt = NULL;
  AVDictionary *opts = NULL;
  int ret;
  unsigned int i;
  int stream_index = 0;
  int *stream_mapping = NULL;
  int stream_mapping_size = 0;

  pkt = av_packet_alloc();
  if (!pkt) {
    fprintf(stderr, "Could not allocate AVPacket\n");
    return 1;
  }

  /* Allow the protocols needed for local files and HLS-style inputs. */
  av_dict_set(&opts, "protocol_whitelist",
              "concat,file,http,https,tcp,tls,crypto", 0);

  if ((ret = avformat_open_input(&ifmt_ctx, in_filename, 0, &opts)) < 0) {
    fprintf(stderr, "Could not open input file '%s'\n", in_filename);
    goto end;
  }

  if ((ret = avformat_find_stream_info(ifmt_ctx, 0)) < 0) {
    fprintf(stderr, "Failed to retrieve input stream information.\n");
    goto end;
  }

  DEBUG_PRINT("duration: %.2Lf\n",
              (long double)ifmt_ctx->duration / AV_TIME_BASE);
  /* Total duration (seconds) for progress reporting. */
  stat_g->total = ifmt_ctx->duration / AV_TIME_BASE;

  av_dump_format(ifmt_ctx, 0, in_filename, 0);

  avformat_alloc_output_context2(&ofmt_ctx, NULL, NULL, out_filename);
  if (!ofmt_ctx) {
    fprintf(stderr, "Could not create output context.\n");
    ret = AVERROR_UNKNOWN;
    goto end;
  }

  stream_mapping_size = ifmt_ctx->nb_streams;
  stream_mapping = av_calloc(stream_mapping_size, sizeof(*stream_mapping));
  if (!stream_mapping) {
    ret = AVERROR(ENOMEM);
    goto end;
  }

  ofmt = ofmt_ctx->oformat;

  /* Map each audio/video/subtitle input stream to an output stream;
   * everything else (data, attachments, ...) is dropped (-1). */
  for (i = 0; i < ifmt_ctx->nb_streams; i++) {
    AVStream *out_stream;
    AVStream *in_stream = ifmt_ctx->streams[i];
    AVCodecParameters *in_codecpar = in_stream->codecpar;

    if (in_codecpar->codec_type != AVMEDIA_TYPE_AUDIO &&
        in_codecpar->codec_type != AVMEDIA_TYPE_VIDEO &&
        in_codecpar->codec_type != AVMEDIA_TYPE_SUBTITLE) {
      stream_mapping[i] = -1;
      continue;
    }

    stream_mapping[i] = stream_index++;

    out_stream = avformat_new_stream(ofmt_ctx, NULL);
    if (!out_stream) {
      fprintf(stderr, "Failed allocating output stream\n");
      ret = AVERROR_UNKNOWN;
      goto end;
    }

    ret = avcodec_parameters_copy(out_stream->codecpar, in_codecpar);
    if (ret < 0) {
      fprintf(stderr, "Failed to copy codec parameters\n");
      goto end;
    }
    /* Let the output muxer pick its own codec tag. */
    out_stream->codecpar->codec_tag = 0;
  }
  av_dump_format(ofmt_ctx, 0, out_filename, 1);

  if (!(ofmt->flags & AVFMT_NOFILE)) {
    ret = avio_open(&ofmt_ctx->pb, out_filename, AVIO_FLAG_WRITE);
    if (ret < 0) {
      fprintf(stderr, "Could not open output file '%s'", out_filename);
      goto end;
    }
  }

  ret = avformat_write_header(ofmt_ctx, &opts);
  if (ret < 0) {
    fprintf(stderr, "Error occurred when opening output file\n");
    goto end;
  }

  while (1) {
    AVStream *in_stream, *out_stream;

    ret = av_read_frame(ifmt_ctx, pkt);
    if (ret < 0)
      break; /* AVERROR_EOF on normal end of input */

    in_stream = ifmt_ctx->streams[pkt->stream_index];
    if (pkt->stream_index >= stream_mapping_size ||
        stream_mapping[pkt->stream_index] < 0) {
      av_packet_unref(pkt);
      continue;
    }

    pkt->stream_index = stream_mapping[pkt->stream_index];
    out_stream = ofmt_ctx->streams[pkt->stream_index];
    // log_packet(ifmt_ctx, pkt, "in");

    /* copy packet */
    av_packet_rescale_ts(pkt, in_stream->time_base, out_stream->time_base);
    pkt->pos = -1;
    // log_packet(ofmt_ctx, pkt, "out");

    /* Progress: current position in seconds, clamped to the total. */
    stat_g->cur = (stat_g->cur <= stat_g->total)
                      ? av_q2d(in_stream->time_base) * pkt->pts
                      : stat_g->total;
    ret = av_interleaved_write_frame(ofmt_ctx, pkt);
    /* pkt is now blank (av_interleaved_write_frame() takes ownership of
     * its contents and resets pkt), so that no unreferencing is necessary.
     * This would be different if one used av_write_frame(). */
    if (ret < 0) {
      fprintf(stderr, "Error muxing packet\n");
      break;
    }
  }

  av_write_trailer(ofmt_ctx);
end:
  av_dict_free(&opts); /* free any options left unconsumed by the demuxer/muxer */
  av_packet_free(&pkt);

  avformat_close_input(&ifmt_ctx);

  /* close output -- check ofmt_ctx->oformat, NOT the local `ofmt`, which is
   * still NULL if we bailed out before it was assigned. */
  if (ofmt_ctx && !(ofmt_ctx->oformat->flags & AVFMT_NOFILE))
    avio_closep(&ofmt_ctx->pb);
  avformat_free_context(ofmt_ctx);

  av_freep(&stream_mapping);

  if (ret < 0 && ret != AVERROR_EOF) {
    fprintf(stderr, "Error occurred: %s\n", av_err2str(ret));
    return 1;
  }

  // Delete the separate partial file now that it is remuxed
  if (remove(in_filename) != 0) {
    append_log("Error deleting partial file %s\n", in_filename);
  }

  return 0;
}