
import math
import random
from collections import Counter

def upscale2x(map, old_size, mode='nearest_neighbour'):
    new_size = old_size*2
    new_map = []

    def get_index(x, y, size):
        return min(max(math.floor(x), 0), size-1) + min(max(math.floor(y), 0), size-1) * size

    match mode:
        case 'nearest_neighbour':
            for y in range(new_size):
                for x in range(new_size):
                    i = math.floor(x/2) + math.floor(y/2) * old_size
                    new_map.append(map[i])

        case 'average':
            
            # https://www.researchgate.net/figure/Discrete-approximation-of-the-Gaussian-kernels-3x3-5x5-7x7_fig2_325768087
            KERNEL = [
                [1,4,7,4,1],
                [4,16,26,16,4],
                [7,26,41,26,7],
                [4,16,26,16,4],
                [1,4,7,4,1],
            ]


            for y in range(new_size):
                for x in range(new_size):
                    # could be parallelised of course
                    # get adj pixels
                    val = 0

                    val += 1 * map[get_index((x-2)/2, (y+2)/2, old_size)]
                    val += 4 * map[get_index((x-1)/2, (y+2)/2, old_size)]
                    val += 7 * map[get_index((x)/2, (y+2)/2, old_size)]
                    val += 4 * map[get_index((x+1)/2, (y+2)/2, old_size)]
                    val += 1 * map[get_index((x+2)/2, (y+2)/2, old_size)]

                    val += 4 * map[get_index((x-2)/2, (y+1)/2, old_size)]
                    val += 16 * map[get_index((x-1)/2, (y+1)/2, old_size)]
                    val += 26.001 * map[get_index((x)/2, (y+1)/2, old_size)]
                    val += 16 * map[get_index((x+1)/2, (y+1)/2, old_size)]
                    val += 4 * map[get_index((x+2)/2, (y+1)/2, old_size)]

                    val += 7 * map[get_index((x-2)/2, (y)/2, old_size)]
                    val += 26.001 * map[get_index((x-1)/2, (y)/2, old_size)]
                    val += 41.11 * map[get_index((x)/2, (y)/2, old_size)]
                    val += 26.001 * map[get_index((x+1)/2, (y)/2, old_size)]
                    val += 7 * map[get_index((x+2)/2, (y)/2, old_size)]

                    val += 4 * map[get_index((x-2)/2, (y-1)/2, old_size)]
                    val += 16 * map[get_index((x-1)/2, (y-1)/2, old_size)]
                    val += 26.001 * map[get_index((x)/2, (y-1)/2, old_size)]
                    val += 16 * map[get_index((x+1)/2, (y-1)/2, old_size)]
                    val += 4 * map[get_index((x+2)/2, (y-1)/2, old_size)]

                    val += 1 * map[get_index((x-2)/2, (y-2)/2, old_size)]
                    val += 4 * map[get_index((x-1)/2, (y-2)/2, old_size)]
                    val += 7 * map[get_index((x)/2, (y-2)/2, old_size)]
                    val += 4 * map[get_index((x+1)/2, (y-2)/2, old_size)]
                    val += 1 * map[get_index((x+2)/2, (y-2)/2, old_size)]

                    new_map.append(int(math.floor(val/273)))

        case 'random': # gets all in a 3x3 grid, picks randomly from it
            for y in range(new_size):
                for x in range(new_size):

                    all_values = []

                    all_values.append(map[get_index((x-1)/2, (y+1)/2, old_size)])
                    all_values.append(map[get_index((x)/2, (y+1)/2, old_size)])
                    all_values.append(map[get_index((x+1)/2, (y+1)/2, old_size)])

                    all_values.append(map[get_index((x-1)/2, (y)/2, old_size)])
                    all_values.append(map[get_index((x)/2, (y)/2, old_size)])
                    all_values.append(map[get_index((x+1)/2, (y)/2, old_size)])

                    all_values.append(map[get_index((x-1)/2, (y-1)/2, old_size)])
                    all_values.append(map[get_index((x)/2, (y-1)/2, old_size)])
                    all_values.append(map[get_index((x+1)/2, (y-1)/2, old_size)])
                    
                    new_map.append(random.choice(all_values))
        
        case 'mode': # gets all in a 3x3 grid, picks mode
            for y in range(new_size):
                for x in range(new_size):

                    all_values = []

                    all_values.append(map[get_index((x-1)/2, (y+1)/2, old_size)])
                    all_values.append(map[get_index((x)/2, (y+1)/2, old_size)])
                    all_values.append(map[get_index((x+1)/2, (y+1)/2, old_size)])

                    all_values.append(map[get_index((x-1)/2, (y)/2, old_size)])
                    all_values.append(map[get_index((x)/2, (y)/2, old_size)])
                    all_values.append(map[get_index((x+1)/2, (y)/2, old_size)])

                    all_values.append(map[get_index((x-1)/2, (y-1)/2, old_size)])
                    all_values.append(map[get_index((x)/2, (y-1)/2, old_size)])
                    all_values.append(map[get_index((x+1)/2, (y-1)/2, old_size)])
                    
                    def get_mode(lst):
                        data = Counter(lst)
                        return data.most_common(1)[0][0]
                    
                    new_map.append(get_mode(all_values))



    return new_map


if __name__ == '__main__':
    # Demo: upscale a 4x4 material map by majority vote ('mode').
    #
    # Other quick checks used during development:
    #   height_map = (0, 10, 20, 30,
    #                 4, 5, 6, 7,
    #                 80, 90, 100, 110,
    #                 12, 13, 14, 15)
    #   print(upscale2x(height_map, 4, 'average'))
    #   print(sum(upscale2x((1000, 2000, 1000, 1000), 2, 'average')))
    material_rows = (
        ('a', 'b', 'c', 'd'),
        ('a', 'b', 'c', 'd'),
        ('a', 'a', 'c', 'c'),
        ('x', 'a', 'a', 'c'),
    )
    # Flatten to the row-major layout upscale2x expects.
    mat_map = tuple(cell for row in material_rows for cell in row)

    print(upscale2x(mat_map, len(material_rows), 'mode'))