import math
import os
import random
import sys

import mido

# Controller numbers that process_midi passes through (after remapping CC2 to
# CC1 when a track has no CC1); every other control_change is dropped.
# Presumably the standard MIDI assignments: 1 = modulation, 2 = breath,
# 11 = expression.
SMOOTH_CCS = {1, 2, 11}
# Half-width of the uniform jitter randomize_value adds (~2.36 value units).
# NOTE(review): the name suggests "3 percent", but 3% of the 0..127 range would
# be 127 * 3 / 100 (~3.81); this expression evaluates 3 * 100 / 127 instead.
# Confirm which was intended.
PERCENT = 100/127 * 3

def find_and_remove_micro_notes(mid, min_ticks=10):
    """Delete very short notes nested inside a longer note of the same pitch.

    For every track of *mid*, note starts and ends are paired per
    ``(note, channel)`` (most recent unmatched start first).  A note lasting at
    most *min_ticks* ticks whose span lies entirely within another note of the
    same pitch and channel is removed: both its note-on and its note-off.

    A note ends on a ``note_off`` of any release velocity, or on a
    ``note_on`` with velocity 0.

    The delta ``time`` of every removed message is added to the message that
    follows it, so the absolute timing of all remaining events is unchanged.

    Args:
        mid: object with a ``tracks`` attribute; each track is a mutable list
            of messages whose ``time`` is a delta in ticks.
        min_ticks: inclusive maximum length, in ticks, for a nested note to be
            treated as a micro note.

    Returns:
        *mid* itself, modified in place.
    """
    for track in mid.tracks:
        absolute_time = 0
        active = {}        # (note, channel) -> stack of (start_time, index)
        spans_by_key = {}  # (note, channel) -> [(start, end, idx_on, idx_off)]

        for i, msg in enumerate(track):
            absolute_time += msg.time

            if msg.type == 'note_on' and msg.velocity > 0:
                active.setdefault((msg.note, msg.channel), []).append((absolute_time, i))
            elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
                key = (msg.note, msg.channel)
                stack = active.get(key)
                if stack:
                    # LIFO pairing: a start nested inside an earlier one is
                    # closed first, so nested notes stay nested.
                    start_time, i_on = stack.pop()
                    spans_by_key.setdefault(key, []).append(
                        (start_time, absolute_time, i_on, i))

        to_remove = set()
        # Only notes with the same pitch and channel can contain each other.
        for spans in spans_by_key.values():
            for span in spans:
                start, end, i_on, i_off = span
                if end - start > min_ticks:
                    continue
                if any(other is not span and other[0] <= start and end <= other[1]
                       for other in spans):
                    to_remove.update((i_on, i_off))

        # Delete from the back so pending indices stay valid.  Each removed
        # message's delta time is folded into its successor so that later
        # events keep their absolute position.
        for i in sorted(to_remove, reverse=True):
            if i + 1 < len(track):
                track[i + 1].time += track[i].time
            del track[i]

    return mid

def randomize_value(current_value):
    """Return *current_value* nudged by a random amount, as a 7-bit value.

    A uniform offset in ``[-PERCENT, PERCENT]`` is added, the sum is rounded
    half-up to an integer, and the result is clamped to ``0..127``.
    """
    jittered = current_value + random.uniform(-PERCENT, PERCENT)
    nearest = math.floor(jittered + 0.5)
    return min(max(nearest, 0), 127)

def _default_outfile(midifile, extension):
    """Return *midifile* with its extension (if any) replaced by *extension*.

    Uses ``os.path.splitext`` so ``.mid``, ``.midi`` and extension-less names
    are all handled, rather than assuming a four-character suffix.
    """
    root, _ext = os.path.splitext(midifile)
    return root + extension


def _convert_track(track, randomize):
    """Build and return the output copy of one source track.

    - program_change messages are dropped;
    - control_change messages are kept only for controllers in SMOOTH_CCS
      (CC2 is renamed to CC1 when the track contains no CC1), and a CC event
      is skipped when it repeats the previous value of that controller;
    - when *randomize* is true, note-on velocities are jittered with
      randomize_value but kept >= 1;
    - CC11 and/or CC1 with value 127 are appended one tick after the last
      event when the source track contains no such controller.

    Messages are re-emitted at their original absolute tick, with delta times
    recomputed for the output track.
    """
    new_track = mido.MidiTrack()
    abs_time = 0
    last_out_time = 0

    # Per-controller thinning state, keyed by controller number only, so CCs
    # on different channels within one track share an entry.
    state = {
        cc: {"last_written_val": None, "last_obs_val": None, "direction": None}
        for cc in SMOOTH_CCS
    }

    def write_message(msg, event_abs):
        # Emit a copy of msg at absolute tick event_abs as a delta-timed event.
        nonlocal last_out_time
        delta = event_abs - last_out_time
        last_out_time = event_abs
        new_track.append(msg.copy(time=delta))

    # Controller numbers present in the source, before any CC2 -> CC1 remap.
    ccs = {msg.control for msg in track if msg.type == "control_change"}

    # Channel of the most recent note_on; used for the appended default CCs.
    track_channel = 0
    for msg in track:
        abs_time += msg.time
        if msg.type == 'note_on':
            track_channel = msg.channel

        if msg.type == "program_change":
            continue

        if msg.type == "control_change":
            cc = msg.control
            if cc == 2 and 1 not in ccs:
                cc = 1
            msg.control = cc
            if cc not in SMOOTH_CCS:
                # Drop every unsupported controller.
                continue

            v = msg.value
            s = state[cc]

            # First event for this controller is always written.
            if s["last_written_val"] is None:
                write_message(msg, abs_time)
                s["last_written_val"] = v
                s["last_obs_val"] = v
                s["direction"] = None
                continue

            prev = s["last_obs_val"]
            direction = s["direction"]
            new_dir = "up" if v > prev else "down" if v < prev else direction

            # NOTE(review): "direction" is reset to None on every write and is
            # otherwise only reassigned to itself, so it is always None here.
            # Net effect: a CC event is written iff its value differs from the
            # previous value seen for this controller. Confirm whether thinning
            # of monotonic +/-1 runs was intended.
            write_now = False
            if v != prev:
                if direction is not None and new_dir != direction:
                    write_now = True
                elif abs(v - prev) > 1:
                    write_now = True
                elif direction is None:
                    write_now = True

            if write_now:
                write_message(msg, abs_time)
                s["last_written_val"] = v
                s["direction"] = None
            else:
                s["direction"] = new_dir

            s["last_obs_val"] = v
            continue

        if msg.type == "note_on" and msg.velocity > 0 and randomize:
            # A velocity-0 note_on means note-off, so jitter must never
            # produce 0 (it would silence the note and orphan its note-off).
            msg.velocity = max(1, randomize_value(msg.velocity))

        write_message(msg, abs_time)

    # NOTE(review): these defaults are placed one tick after the last source
    # event (typically end_of_track), i.e. after all notes. The checks use the
    # pre-remap controller numbers, so a track whose CC2 was renamed to CC1
    # still receives the default CC1. Confirm both are intended.
    if 11 not in ccs:
        default_cc = mido.Message('control_change', channel=track_channel,
                                  control=11, value=127)
        write_message(default_cc, abs_time + 1)
    if 1 not in ccs:
        default_cc = mido.Message('control_change', channel=track_channel,
                                  control=1, value=127)
        write_message(default_cc, abs_time + 1)

    return new_track


def process_midi(args, randomize=False):
    """Clean up a MIDI file and write the result.

    Args:
        args: argv-style list. ``args[1]`` is the input MIDI path and the
            optional ``args[2]`` is the output path. Without ``args[2]``, the
            output is the input path with its extension replaced by
            ``_out.mid`` (or ``_random.mid`` when *randomize* is true).
        randomize: jitter note-on velocities via randomize_value.

    Micro notes (<= 10 ticks, nested in a same-pitch note) are first removed
    from the loaded file. Each track is then rebuilt by _convert_track, and
    the result is saved. Prints a usage line to stderr if no input is given.
    """
    extension = '_random.mid' if randomize else '_out.mid'

    if len(args) < 2:
        prog = args[0] if args else 'process_midi'
        print(f'usage: {prog} input.mid [output.mid]', file=sys.stderr)
        return

    midifile = args[1]
    outfile = args[2] if len(args) > 2 else _default_outfile(midifile, extension)

    mid = mido.MidiFile(midifile)
    out = mido.MidiFile(type=mid.type, ticks_per_beat=mid.ticks_per_beat)

    find_and_remove_micro_notes(mid, min_ticks=10)
    for track in mid.tracks:
        out.tracks.append(_convert_track(track, randomize))

    out.save(outfile)
    print(f'File {outfile} has been written.')

if __name__ == '__main__':
    # Script entry point: convert sys.argv[1] with velocity randomization on.
    process_midi(sys.argv, randomize=True)