#include "graphics.h"
#include "gfx/bg_lab.h"
#include "gfx/bg_ice.h"
#include "gfx/bg_cave.h"
#include "gfx/bg_city.h"
#include "gfx/bg_mansion.h"
#include "gfx/bg_powerp.h"
#include "gfx/bg_special1.h"
#include "gfx/bg_special2.h"
#include "gfx/bg_woods.h"
#include "gfx/sprites_global.h"
#include "gfx/sprites_lab.h"
#include "gfx/sprites_ice.h"
#include "gfx/sprites_cave.h"
#include "gfx/sprites_city.h"
#include "gfx/sprites_mansion.h"
#include "gfx/sprites_powerp.h"
#include "gfx/sprites_special1.h"
#include "gfx/sprites_special2.h"
#include "gfx/sprites_woods.h"
#include <string.h>

// Shadow copy of all 128 OAM entries. The sprite helpers in this file only
// modify this array; updateOAM() copies it into hardware OAM.
OAMEntry shadowOAM[128];

// Palette snapshots taken by storeOriginalPalettes(). fadeFromBlack() fades
// toward these and does nothing until palettesStored has been set to 1.
static u16 originalBgPal[16];
static u16 originalObjPal[256];
static int palettesStored = 0;
// World whose tiles were loaded most recently (-1 = none yet).
// loadBackgroundTileset() and loadLevelTileset() skip the tile copy when asked
// for this same world; resetGraphicsForLevel() updates it.
static WorldType previousWorldType = -1;
// Copy of shadowOAM taken at the end of resetSprites(); resetSprites() reads
// slot 2 from it to tell whether the tens digit of the counter was visible.
static OAMEntry previousOAM[128];
// Coin count last drawn by resetSprites(); -1 forces the next redraw.
static int previousCoinCount = -1;
// Per-slot "in use" flags (1/0) and how many are set. resetSprites() hides
// slots that were flagged on the previous call but are no longer used.
static int activeSprites[128];
static int activeSpriteCount = 0;

// Fallback background palette: initGraphics() uses these entries for any of
// the first 16 BG palette slots that bg_menuPal does not cover.
const u16 originalPal[16] __attribute__((aligned(4))) = {
    0x20A7, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
    0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000
};

// 2x2 tile blocks for fillBackgroundWithTileset(): each table lists the tile
// indices written to the top-left, top-right, bottom-left and bottom-right
// map cells. Tables 1..4 are picked with 50% / 30% / 15% / 5% probability.
// BG_MENU_TILE_* is used by initGraphics() and BG_TILE_* by fillBackground();
// the two sets currently hold identical indices.
const u16 BG_MENU_TILE_1[4] = {0, 1, 4, 5};
const u16 BG_MENU_TILE_2[4] = {2, 3, 6, 7};
const u16 BG_MENU_TILE_3[4] = {8, 9, 12, 13};
const u16 BG_MENU_TILE_4[4] = {10, 11, 14, 15};

const u16 BG_TILE_1[4] = {0, 1, 4, 5};
const u16 BG_TILE_2[4] = {2, 3, 6, 7};
const u16 BG_TILE_3[4] = {8, 9, 12, 13};
const u16 BG_TILE_4[4] = {10, 11, 14, 15};

// Initialize graphics for the start screen
void initGraphics(void) {
    BG0CNT = 0x4100;

    for (int i = 0; i < 16; i++) {
        BG_PAL[i] = (i < bg_menuPalLen / 2) ? bg_menuPal[i] : originalPal[i];
    }

    volatile u16* vramTile = VRAM_BG;
    for (int i = 0; i < bg_menuTilesLen / 2; i++) {
        vramTile[i] = bg_menuTiles[i];
    }

    fillBackgroundWithTileset(BG_MENU_TILE_1, BG_MENU_TILE_2, BG_MENU_TILE_3, BG_MENU_TILE_4);

    for (int i = 0; i < sprites_globalPalLen / 2; i++) {
        PAL_OBJ[i] = sprites_globalPal[i];
    }
    PAL_OBJ[0] = 0;

    volatile u16* vramObj = VRAM_OBJ;
    for (int i = 0; i < sprites_globalTilesLen / 2; i++) {
        vramObj[i] = sprites_globalTiles[i];
    }

    for (int i = 0; i < 128; i++) {
        hideSprite(i);
        previousOAM[i].attr0 = 0x0200;
        activeSprites[i] = 0;
    }
    activeSpriteCount = 0;
}

// Initialize graphics for a game level
void initLevelGraphics(void) {
    BG0CNT = 0x4100;

    for (int i = 0; i < 16; i++) {
        BG_PAL[i] = 0x0000;
    }
    for (int i = 0; i < 256; i++) {
        PAL_OBJ[i] = 0x0000;
    }

    for (int i = 0; i < sprites_globalPalLen / 2; i++) {
        PAL_OBJ[i] = sprites_globalPal[i];
    }
    PAL_OBJ[0] = 0;

    volatile u16* vramObj = VRAM_OBJ;
    for (int i = 0; i < sprites_globalTilesLen / 2; i++) {
        vramObj[i] = sprites_globalTiles[i];
    }

    for (int i = 0; i < 128; i++) {
        hideSprite(i);
        previousOAM[i].attr0 = 0x0200;
        activeSprites[i] = 0;
    }
    activeSpriteCount = 0;
}

// Load background tileset based on world type
// Load the background palette for `worldType` and, if the world differs from
// previousWorldType, its background tiles into VRAM_BG.
//
// Display layers are switched off (REG_DISPCNT = 0) while loading. The first
// 16 BG palette entries are cleared before the world palette is copied in.
// Unknown world types fall back to the lab assets. Afterwards mode 0 is
// re-enabled with BG0 and sprites (1D sprite mapping), and the function waits
// for the next VBlank interrupt.
void loadBackgroundTileset(WorldType worldType) {
    REG_DISPCNT = 0;

    for (int entry = 0; entry < 16; ++entry) {
        BG_PAL[entry] = 0x0000;
    }

    const bool copyTiles = worldType != previousWorldType;
    volatile u16* const bgTiles = VRAM_BG;

    // Copy <prefix>Pal into BG_PAL and, when copyTiles is set, <prefix>Tiles
    // into VRAM_BG. The *Len values are halved for these 16-bit copies
    // (they appear to be byte counts).
#define LOAD_WORLD_BG(prefix)                                   \
    do {                                                        \
        for (int k = 0; k < prefix##PalLen / 2; ++k) {          \
            BG_PAL[k] = prefix##Pal[k];                         \
        }                                                       \
        if (copyTiles) {                                        \
            for (int k = 0; k < prefix##TilesLen / 2; ++k) {    \
                bgTiles[k] = prefix##Tiles[k];                  \
            }                                                   \
        }                                                       \
    } while (0)

    switch (worldType) {
        case WORLD_ICE:      LOAD_WORLD_BG(bg_ice);      break;
        case WORLD_CAVE:     LOAD_WORLD_BG(bg_cave);     break;
        case WORLD_CITY:     LOAD_WORLD_BG(bg_city);     break;
        case WORLD_MANSION:  LOAD_WORLD_BG(bg_mansion);  break;
        case WORLD_POWERP:   LOAD_WORLD_BG(bg_powerp);   break;
        case WORLD_WOODS:    LOAD_WORLD_BG(bg_woods);    break;
        case WORLD_SPECIAL1: LOAD_WORLD_BG(bg_special1); break;
        case WORLD_SPECIAL2: LOAD_WORLD_BG(bg_special2); break;
        case WORLD_LAB:      // the lab assets double as the fallback
        default:             LOAD_WORLD_BG(bg_lab);      break;
    }
#undef LOAD_WORLD_BG

    REG_DISPCNT = MODE_0 | BG0_ENABLE | OBJ_ENABLE | OBJ_1D_MAP;
    VBlankIntrWait();
}

// Load world-specific sprite tileset
// Copy the world-specific sprite tiles for `worldType` into sprite VRAM,
// directly after the global sprite tiles.
//
// Nothing is copied when worldType equals previousWorldType. Unknown world
// types fall back to the lab sprite tiles.
void loadLevelTileset(WorldType worldType) {
    if (worldType == previousWorldType) {
        return;
    }

    // World tiles start right after the global set (length halved for the
    // 16-bit pointer arithmetic).
    volatile u16* const dst = VRAM_OBJ + (sprites_globalTilesLen / 2);

#define COPY_WORLD_SPRITES(prefix)                              \
    do {                                                        \
        for (int w = 0; w < prefix##TilesLen / 2; ++w) {        \
            dst[w] = prefix##Tiles[w];                          \
        }                                                       \
    } while (0)

    switch (worldType) {
        case WORLD_ICE:      COPY_WORLD_SPRITES(sprites_ice);      break;
        case WORLD_CAVE:     COPY_WORLD_SPRITES(sprites_cave);     break;
        case WORLD_CITY:     COPY_WORLD_SPRITES(sprites_city);     break;
        case WORLD_MANSION:  COPY_WORLD_SPRITES(sprites_mansion);  break;
        case WORLD_POWERP:   COPY_WORLD_SPRITES(sprites_powerp);   break;
        case WORLD_WOODS:    COPY_WORLD_SPRITES(sprites_woods);    break;
        case WORLD_SPECIAL1: COPY_WORLD_SPRITES(sprites_special1); break;
        case WORLD_SPECIAL2: COPY_WORLD_SPRITES(sprites_special2); break;
        case WORLD_LAB:      // the lab sprites double as the fallback
        default:             COPY_WORLD_SPRITES(sprites_lab);      break;
    }
#undef COPY_WORLD_SPRITES
}

// Set sprite attributes at specified index
void setSprite(int index, int x, int y, int tile) {
    u16 newAttr0 = (y & 0xFF) | (0 << 8) | (0 << 10) | (0 << 12) | (1 << 13);
    u16 newAttr1 = (x & 0x1FF) | (0 << 9) | (1 << 14);
    u16 newAttr2 = tile | (0 << 12);

    if (shadowOAM[index].attr0 != newAttr0 ||
        shadowOAM[index].attr1 != newAttr1 ||
        shadowOAM[index].attr2 != newAttr2) {
        shadowOAM[index].attr0 = newAttr0;
        shadowOAM[index].attr1 = newAttr1;
        shadowOAM[index].attr2 = newAttr2;
    }
}

// Hide sprite at specified index
// Hide the sprite in shadow-OAM slot `index`.
//
// Writes attr0 = 0x0200: bit 9 set with the affine bit (8) clear, which the
// GBA treats as "object disabled" (GBATEK OBJ attribute 0). attr1/attr2 are
// left unchanged. Out-of-range slots are ignored instead of writing past
// shadowOAM. Only the shadow OAM changes; call updateOAM() to push it.
void hideSprite(int index) {
    if (index < 0 || index >= (int)(sizeof shadowOAM / sizeof shadowOAM[0])) {
        return;
    }
    shadowOAM[index].attr0 = 0x0200;
}

// Set sprite attributes with priority
void setSpriteWithPriority(int index, int x, int y, int tile, int priority) {
    u16 newAttr0 = (y & 0xFF) | (0 << 8) | (0 << 10) | (0 << 12) | (1 << 13);
    u16 newAttr1 = (x & 0x1FF) | (0 << 9) | (1 << 14);
    u16 newAttr2 = tile | (priority << 10);

    if (shadowOAM[index].attr0 != newAttr0 ||
        shadowOAM[index].attr1 != newAttr1 ||
        shadowOAM[index].attr2 != newAttr2) {
        shadowOAM[index].attr0 = newAttr0;
        shadowOAM[index].attr1 = newAttr1;
        shadowOAM[index].attr2 = newAttr2;
    }
}

// Update OAM with shadow OAM contents
void updateOAM(void) {
    memcpy((void*)OAM, shadowOAM, sizeof(shadowOAM));
}

// Store current palettes for fading effects
// Snapshot the current palettes so fadeFromBlack() knows what to fade to.
//
// Copies the first 16 BG palette entries into originalBgPal and all 256
// sprite palette entries into originalObjPal, then sets palettesStored.
void storeOriginalPalettes(void) {
    const int bgEntries = (int)(sizeof originalBgPal / sizeof originalBgPal[0]);
    const int objEntries = (int)(sizeof originalObjPal / sizeof originalObjPal[0]);
    int idx = 0;

    while (idx < objEntries) {
        if (idx < bgEntries) {
            originalBgPal[idx] = BG_PAL[idx];
        }
        originalObjPal[idx] = PAL_OBJ[idx];
        idx++;
    }

    palettesStored = 1;
}

// Fade screen to black over specified steps
// Fade the first 16 BG palette entries and all 256 sprite palette entries
// linearly from their current colours to black.
//
// Performs steps + 1 iterations, one per frame. Each iteration waits while
// VCOUNT (0x04000006) is in the VBlank lines (>= 160), writes the palettes,
// then waits for the next VBlank. At iteration `step`, each 5-bit channel is
// scaled to (steps - step) / steps of its starting value, so the final
// iteration writes pure black.
//
// steps == 0 cuts straight to black (no division by zero). A negative steps
// value does nothing to the palettes.
void fadeToBlack(int steps) {
    // Static to keep ~0.5 KiB off the stack.
    static u16 startBg[16];
    static u16 startObj[256];
    volatile u16 *const vcount = (volatile u16 *)0x04000006;
    const int den = (steps > 0) ? steps : 1;

    // Scale from a snapshot, not from the live palette: re-reading the
    // already-darkened colours each step would compound the scaling and make
    // the fade non-linear.
    for (int i = 0; i < 16; i++) {
        startBg[i] = BG_PAL[i];
    }
    for (int i = 0; i < 256; i++) {
        startObj[i] = PAL_OBJ[i];
    }

    for (int step = 0; step <= steps; step++) {
        const int num = steps - step;

        while (*vcount >= 160) { }
        for (int i = 0; i < 16; i++) {
            u16 color = startBg[i];
            u16 r = (color & 0x001F) * num / den;
            u16 g = ((color & 0x03E0) >> 5) * num / den;
            u16 b = ((color & 0x7C00) >> 10) * num / den;
            BG_PAL[i] = (r & 0x001F) | ((g & 0x001F) << 5) | ((b & 0x001F) << 10);
        }
        for (int i = 0; i < 256; i++) {
            u16 color = startObj[i];
            u16 r = (color & 0x001F) * num / den;
            u16 g = ((color & 0x03E0) >> 5) * num / den;
            u16 b = ((color & 0x7C00) >> 10) * num / den;
            PAL_OBJ[i] = (r & 0x001F) | ((g & 0x001F) << 5) | ((b & 0x001F) << 10);
        }
        while (*vcount < 160) { }
    }
}

// Fade screen from black to original palettes
// Fade the palettes from black up to the snapshot taken by
// storeOriginalPalettes(). Does nothing if no snapshot has been stored.
//
// Performs steps + 1 iterations, one per frame, using the same VCOUNT-based
// frame pacing as fadeToBlack(). At iteration `step`, each 5-bit channel is
// step / steps of its stored value, so the final iteration restores the
// stored palettes exactly.
//
// steps == 0 restores the stored palettes immediately (no division by zero).
// A negative steps value does nothing to the palettes.
void fadeFromBlack(int steps) {
    if (!palettesStored) return;

    volatile u16 *const vcount = (volatile u16 *)0x04000006;
    const int den = (steps > 0) ? steps : 1;

    for (int step = 0; step <= steps; step++) {
        // With steps == 0 the single iteration uses full brightness (1/1).
        const int num = (steps > 0) ? step : 1;

        while (*vcount >= 160) { }
        for (int i = 0; i < 16; i++) {
            u16 bgColor = originalBgPal[i];
            u16 r = (bgColor & 0x001F) * num / den;
            u16 g = ((bgColor & 0x03E0) >> 5) * num / den;
            u16 b = ((bgColor & 0x7C00) >> 10) * num / den;
            BG_PAL[i] = (r & 0x001F) | ((g & 0x001F) << 5) | ((b & 0x001F) << 10);
        }
        for (int i = 0; i < 256; i++) {
            u16 objColor = originalObjPal[i];
            u16 r = (objColor & 0x001F) * num / den;
            u16 g = ((objColor & 0x03E0) >> 5) * num / den;
            u16 b = ((objColor & 0x7C00) >> 10) * num / den;
            PAL_OBJ[i] = (r & 0x001F) | ((g & 0x001F) << 5) | ((b & 0x001F) << 10);
        }
        while (*vcount < 160) { }
    }
}

// Delay execution for specified frames
// Block for `frames` frames, calling mmFrame() once per frame.
//
// Each frame first waits while VCOUNT (0x04000006) reads >= 160, then waits
// until it reaches 160 again. Non-positive `frames` returns immediately.
void delay(int frames) {
    volatile u16 *const scanline = (volatile u16 *)0x04000006;
    int remaining = frames;

    while (remaining > 0) {
        while (*scanline >= 160) { }
        while (*scanline < 160) { }
        mmFrame();
        remaining--;
    }
}

// Fill background map with randomized tiles
// Cover the 32x32 region of BG_MAP with randomly chosen 2x2 tile blocks.
//
// For each block (row-major order) one rand() roll picks tileSet1 (50%),
// tileSet2 (30%), tileSet3 (15%) or tileSet4 (5%). The chosen set's four
// entries go to the top-left, top-right, bottom-left and bottom-right cells.
void fillBackgroundWithTileset(const u16* tileSet1, const u16* tileSet2, const u16* tileSet3, const u16* tileSet4) {
    const int mapWidth = 32;

    for (int top = 0; top < mapWidth; top += 2) {
        const int upperRow = top * mapWidth;
        const int lowerRow = upperRow + mapWidth;

        for (int left = 0; left < mapWidth; left += 2) {
            const int roll = rand() % 100;
            const u16* block;

            if (roll < 50) {
                block = tileSet1;
            } else if (roll < 80) {
                block = tileSet2;
            } else if (roll < 95) {
                block = tileSet3;
            } else {
                block = tileSet4;
            }

            BG_MAP[upperRow + left] = block[0];
            BG_MAP[upperRow + left + 1] = block[1];
            BG_MAP[lowerRow + left] = block[2];
            BG_MAP[lowerRow + left + 1] = block[3];
        }
    }
}

// Fill background map with default tileset
// Fill the level background map with random 2x2 blocks drawn from the
// level tile tables BG_TILE_1..BG_TILE_4.
void fillBackground(void) {
    fillBackgroundWithTileset(BG_TILE_1,
                              BG_TILE_2,
                              BG_TILE_3,
                              BG_TILE_4);
}

// Update coin counter display
// Draw the coin counter in sprite slots 1 and 2 at priority 0.
//
// Two-digit values put the tens digit in slot 1 at guiX + 4 and the ones
// digit in slot 2 at guiX + 12. Single-digit values show only the ones digit
// in slot 1 at guiX + SPRITE_SIZE / 2 and hide slot 2. Digit d is drawn with
// tile COUNTER_0 + d * 8. activeSprites[1]/[2] are updated to match.
//
// coinCount is clamped to 0..99 so both digits stay within 0..9. Previously,
// larger values produced a tens "digit" of 10+ and negative values a negative
// tile offset, both addressing tiles past the ten digit glyphs (presumably
// not digits).
void updateCounter(int coinCount, int guiX, int guiY) {
    const int counterBase = COUNTER_0;
    int shown = coinCount;

    if (shown < 0) {
        shown = 0;
    } else if (shown > 99) {
        shown = 99;
    }

    const int tens = shown / 10;
    const int ones = shown % 10;

    if (tens > 0) {
        setSpriteWithPriority(1, guiX + 4, guiY, counterBase + tens * 8, 0);
        setSpriteWithPriority(2, guiX + 12, guiY, counterBase + ones * 8, 0);
        activeSprites[1] = 1;
        activeSprites[2] = 1;
    } else {
        setSpriteWithPriority(1, guiX + SPRITE_SIZE / 2, guiY, counterBase + ones * 8, 0);
        hideSprite(2);
        activeSprites[1] = 1;
        activeSprites[2] = 0;
    }
}

// Reset and update sprites for level objects
void resetSprites(Wall* walls, int wallCount, LevelObject* objects, int objectCount, int guiX, int guiY, int coinCount) {
    int newActiveSpriteCount = 0;
    int newActiveSprites[128] = {0};
    int spriteIndex = 3;

    newActiveSprites[0] = 1;
    newActiveSpriteCount++;

    if (coinCount != previousCoinCount) {
        updateCounter(coinCount, guiX, guiY);
        newActiveSprites[1] = 1;
        newActiveSpriteCount++;
        if (coinCount / 10 > 0) {
            newActiveSprites[2] = 1;
            newActiveSpriteCount++;
        }
        previousCoinCount = coinCount;
    } else {
        newActiveSprites[1] = 1;
        newActiveSpriteCount++;
        if (previousOAM[2].attr0 != 0x0200) {
            newActiveSprites[2] = 1;
            newActiveSpriteCount++;
        }
    }

    for (int j = 0; j < wallCount && spriteIndex < 128; j++) {
        setSprite(spriteIndex, walls[j].x, walls[j].y, walls[j].tile);
        newActiveSprites[spriteIndex] = 1;
        newActiveSpriteCount++;
        spriteIndex++;
    }

    for (int j = 0; j < objectCount && spriteIndex < 128; j++) {
        if (!objects[j].collected || objects[j].tile == STOP || objects[j].tile == WARP || 
            objects[j].tile == RAMP_LEFT || objects[j].tile == RAMP_RIGHT || 
            objects[j].tile == POKEBALL || objects[j].tile == EXIT) {
            setSprite(spriteIndex, objects[j].x, objects[j].y, objects[j].tile);
            newActiveSprites[spriteIndex] = 1;
            newActiveSpriteCount++;
            spriteIndex++;
        }
    }

    for (int i = 0; i < 128; i++) {
        if (activeSprites[i] && !newActiveSprites[i] && shadowOAM[i].attr0 != 0x0200) {
            hideSprite(i);
        }
    }

    activeSpriteCount = 0;
    for (int i = 0; i < 128; i++) {
        activeSprites[i] = newActiveSprites[i];
        if (newActiveSprites[i]) {
            activeSpriteCount++;
        }
    }

    memcpy(previousOAM, shadowOAM, sizeof(shadowOAM));

    updateOAM();
}

// Play player spawn animation
// Play the 6-frame spawn animation in sprite slot 0 at (x, y).
//
// Shows spawnTiles[0..5] in turn, pushing OAM and waiting 5 frames after
// each (spawnTiles must hold at least 6 entries). Afterwards slot 0 is
// recorded in previousOAM and marked active.
//
// activeSpriteCount is incremented only if slot 0 was not already active.
// The previous code set the flag before testing it, which made its increment
// branch unreachable and undercounted whenever the count was already
// nonzero.
void playSpawnAnimation(int x, int y, int* spawnTiles) {
    for (int frame = 0; frame < 6; frame++) {
        setSprite(0, x, y, spawnTiles[frame]);
        updateOAM();
        delay(5);
    }

    previousOAM[0] = shadowOAM[0];
    if (!activeSprites[0]) {
        activeSprites[0] = 1;
        activeSpriteCount++;
    }
}

// Reset graphics for a new level
// Prepare palettes, VRAM and sprites for a level in `worldType`.
//
// Display layers are switched off (REG_DISPCNT = 0) before anything is
// cleared, so the VRAM wipe is not shown on screen. When the world differs
// from the last one loaded, the first 0xA000 bytes of VRAM
// (0x06000000-0x06009FFF) are zeroed so stale tiles and map entries from the
// previous world do not linger. The wipe uses explicit 16-bit stores:
// GBA VRAM does not accept ordinary 8-bit stores (they are duplicated or
// ignored), and memset() is allowed to use them.
//
// loadBackgroundTileset() turns the display back on. Finally the palettes
// are snapshotted for fadeFromBlack() and the world is remembered in
// previousWorldType.
void resetGraphicsForLevel(WorldType worldType) {
    REG_DISPCNT = 0;

    if (worldType != previousWorldType) {
        volatile u16 *vram = (volatile u16 *)0x06000000;
        for (int i = 0; i < 0xA000 / 2; i++) {
            vram[i] = 0;
        }
    }

    initLevelGraphics();
    loadBackgroundTileset(worldType);
    loadLevelTileset(worldType);
    fillBackground();
    storeOriginalPalettes();

    previousWorldType = worldType;
}