library(ggplot2)
library(dplyr)

# 1. Cobb–Douglas exponent on x; y receives the complementary weight 1 - alpha
alpha <- 0.3

# 2. Dense 300 x 300 grid of (x, y) bundles, plus the utility of every bundle:
#    U(x, y) = x^alpha * y^(1 - alpha)
#    Both axes start at 0.1 rather than 0.
grid <- expand.grid(
  x = seq(from = 0.1, to = 22, length.out = 300),
  y = seq(from = 0.1, to = 11, length.out = 300)
)
grid[["U"]] <- with(grid, x^alpha * y^(1 - alpha))

# 3. Anchor points and their annotations.
#    With U = x^0.3 * y^0.7 the four anchors evaluate to:
#      top-right    (20, 10): U ≈ 12.3  near the top of the colour scale
#      top-left     ( 1, 10): U ≈  5.0  lower-middle of the scale
#      middle       ( 9,  4): U ≈  5.1  lower-middle of the scale
#      bottom-right (20,  1): U ≈  2.5  near the bottom of the colour scale
#    Note: the top-left and middle anchors have almost the same utility, so
#    they sit on nearly the same colour band of the heatmap.
#
#    Columns:
#      px, py - coordinates of the anchor dot
#      tx, ty - coordinates of the centre of the text label for that dot
#      label  - multi-line annotation text
points_df <- data.frame(
  px = c(20, 1, 9, 20),
  py = c(10, 10, 4, 1),
  tx = c(17, 3.5, 9, 17),
  ty = c(8.8, 9.3, 4.7, 1.8),
  label = c(
    "Having many ice cream scoops\nand many burritos\ngives the\nhighest utility",
    "Having few ice cream scoops\nbut many burritos\ngives some,\nbut not much utility",
    "Having some ice cream scoops\nand some burritos\ngives neither the highest\nnor the lowest utility",
    "Having many ice cream scoops\nbut few burritos\ngives\nlow utility"
  )
)

# Textbook-ready plot
#
# Builds the labelled preference heatmap: a tile layer coloured by utility `U`,
# white dots at the anchor bundles and a bold white annotation next to each.
#
# Takes no arguments; reads two objects from the enclosing (global) scope:
#   grid      - data frame with columns x, y and U (one row per tile)
#   points_df - data frame with columns px, py (dot), tx, ty (text), label
#
# Returns the ggplot object; nothing is drawn or saved here.
textbook_plot <- function() {
  # Identity out-of-bounds handler (same behaviour as scales::oob_keep()).
  # geom_tile() centres every tile on its grid point, so the tiles in the last
  # column (x = 22) and top row (y = 11) reach slightly past the axis limits.
  # The default handler (censor) turns those edges into NA, ggplot2 then drops
  # the tiles with a "Removed rows" warning and leaves an unfilled strip along
  # the top and right of the panel. Keeping the values lets the panel clip the
  # overhang instead, so the heatmap fills the plot right up to the limits.
  keep_oob <- function(x, ...) x

  ggplot(grid, aes(x = x, y = y, fill = U)) +
    # Background heatmap
    geom_tile() +
    scale_fill_viridis_c(
      option = "C",
      # The colourbar shows no ticks or numbers; the padded title acts as the
      # two end labels of the bar instead.
      name = "Least preferred                      Most preferred",
      guide = guide_colorbar(
        title.position = "top",
        barwidth = 20,
        ticks = FALSE,
        label = FALSE
      )
    ) +

    # Axis formatting: fixed limits, no padding around the heatmap
    scale_x_continuous(
      breaks = seq(0, 22, by = 4),
      limits = c(0, 22),
      expand = c(0, 0),
      oob    = keep_oob
    ) +
    scale_y_continuous(
      breaks = seq(0, 11, by = 2),
      limits = c(0, 11),
      expand = c(0, 0),
      oob    = keep_oob
    ) +

    # Axis labels
    labs(
      x = "Ice Cream Scoops (x)",
      y = "Burritos (y)",
      title = ""
    ) +

    # Theme
    theme_minimal() +
    theme(
      plot.title       = element_text(hjust = 0.5, size = 18, face = "bold"),
      axis.title       = element_text(size = 14, face = "bold"),
      axis.text        = element_text(size = 12),
      legend.position  = "top",
      legend.title     = element_text(size = 14, face = "bold"),
      legend.margin    = margin(b = 10),
      plot.margin      = margin(10, 20, 10, 10)
    ) +

    # White anchor points (own data, so the heatmap aesthetics are not inherited)
    geom_point(
      data        = points_df,
      aes(x = px, y = py),
      inherit.aes = FALSE,
      color       = "white",
      size        = 4,
      stroke      = 1
    ) +

    # Labels positioned near their anchor points
    geom_text(
      data        = points_df,
      aes(x = tx, y = ty, label = label),
      inherit.aes = FALSE,
      color       = "white",
      size        = 4.5,
      fontface    = "bold",
      lineheight  = 0.95
    )
}

# Save the figure as a 10 x 7 inch PNG (300 dpi) and as a vector PDF.
#
# The output directory is created first (a no-op when it already exists) so a
# fresh checkout does not fail on a missing folder, and the plot is built once
# and reused for both files.
fig_dir <- file.path("Ch09", "figures")
dir.create(fig_dir, recursive = TRUE, showWarnings = FALSE)

heatmap_plot <- textbook_plot()

# Save high-resolution plot
ggsave(
  file.path(fig_dir, "heatmap_pref_labeled.png"),
  plot   = heatmap_plot,
  width  = 10,
  height = 7,
  dpi    = 300,
  bg     = "white"
)

# For PDF output (vector graphics, better for publishing)
ggsave(
  file.path(fig_dir, "heatmap_pref_labeled.pdf"),
  plot   = heatmap_plot,
  width  = 10,
  height = 7,
  bg     = "white"
)
