import React, { useState } from "react";
import { useQuery } from "react-query";
import {
  Box, Skeleton, Alert, Table, TableHead, TableBody, TableRow, TableCell,
} from "@mui/material";
import ButtonTabs from "@/ui/atoms/ButtonTabs";
import dayjs from "dayjs";
import { styled } from "@mui/system";
import { getActivityFeedTaskMetrics } from "@/api/Metrics";

/** Props accepted by the `TaskDiff` component. */
interface TaskDiffProps {
  /** Ids of the tasks whose metrics are fetched and shown in the diff table. */
  taskIds: string[];
}

/**
 * Bold table cell pinned to the left edge (`position: sticky; left: 0`),
 * painted with the theme's default background so cells scrolling underneath
 * it do not show through.
 */
const StickyCell = styled(TableCell)(({ theme: { palette } }) => {
  const surface = palette.background.default;
  return {
    position: "sticky",
    left: 0,
    backgroundColor: surface,
    zIndex: 1,
    fontWeight: "bold",
  };
});

/**
 * Bold table cell using the theme's "warning" colours, with a CSS filter
 * that lowers saturation and raises brightness of the rendered cell.
 */
const UpdatedCell = styled(TableCell)(({ theme: { palette } }) => {
  const { main, contrastText } = palette.warning;
  return {
    backgroundColor: main,
    color: contrastText,
    fontWeight: "bold",
    filter: "saturate(0.7) brightness(1.2)",
  };
});

/**
 * Bold table cell using the theme's "success" colours, with a CSS filter
 * that lowers saturation and raises brightness of the rendered cell.
 */
const NewCell = styled(TableCell)(({ theme: { palette } }) => {
  const { main, contrastText } = palette.success;
  return {
    backgroundColor: main,
    color: contrastText,
    filter: "saturate(0.7) brightness(1.2)",
    fontWeight: "bold",
  };
});

/**
 * Display order of the period-type tabs. Types not listed here are placed
 * after the known ones instead of producing a NaN comparison.
 */
const PERIOD_TYPE_ORDER: Record<string, number> = {
  MONTHLY: 1,
  QUARTERLY: 2,
  YEARLY: 3,
};

/** Rank of a period type for tab ordering; unknown types sort last. */
function periodTypeRank(periodType: string): number {
  return PERIOD_TYPE_ORDER[periodType] ?? Number.MAX_SAFE_INTEGER;
}

/**
 * True when a metric value is present. Only null, undefined and the empty
 * string count as "missing" — a numeric 0 is a real value and must be shown.
 */
function hasValue(value: unknown): boolean {
  return value !== null && value !== undefined && value !== "";
}

/**
 * Comparator that orders period names newest-first.
 *
 * For QUARTERLY the name is split on a space into "<quarter> <year>" and
 * compared by year, then by quarter. Every other period type is compared by
 * handing the raw names to dayjs.
 * NOTE(review): assumes non-quarterly period names are parseable by dayjs —
 * confirm against the API's periodName format.
 */
function comparePeriodNamesDesc(a: string, b: string, periodType: string | null): number {
  if (periodType === "QUARTERLY") {
    const [aQuarter = "", aYear = ""] = a.split(" ");
    const [bQuarter = "", bYear = ""] = b.split(" ");
    if (aYear === bYear) {
      return bQuarter.localeCompare(aQuarter);
    }
    return Number(bYear) - Number(aYear);
  }
  return dayjs(b).diff(dayjs(a));
}

/**
 * Table of metrics for the given tasks: one row per metric label, one column
 * per period of the active period type (newest first).
 *
 * A cell is rendered as "updated" (previous value > current value) when the
 * metric carries a previous value, as "new" when it only has a current value,
 * and as an empty cell when there is no metric for that label/period.
 *
 * @param taskIds Ids passed to `getActivityFeedTaskMetrics`; also part of the
 *   query key, so a different set of ids triggers a new fetch.
 */
export default function TaskDiff({
  taskIds,
}: TaskDiffProps) {
  // Holds only an explicit user choice; the default tab is derived below.
  const [selectedPeriod, setSelectedPeriod] = useState<string | null>(null);

  const {
    data,
    isLoading,
    isError,
  } = useQuery(
    ["activityFeed", "diff", ...taskIds],
    () => getActivityFeedTaskMetrics(taskIds),
    {
      refetchInterval: 0,
      refetchOnWindowFocus: false,
    },
  );

  if (isLoading) {
    return (
      <Box
        py={1}
        display="flex"
        flexDirection="column"
        gap={1}
      >
        <Skeleton variant="rounded" width="100%" height={24} />
        <Skeleton variant="rounded" width="100%" height={24} />
        <Skeleton variant="rounded" width="100%" height={24} />
      </Box>
    );
  }

  if (isError) {
    return (
      <Box>
        <Alert severity="error">
          Error loading metrics. Please refresh the page or contact the labs team.
        </Alert>
      </Box>
    );
  }

  // Guard against a settled query that carries no payload.
  const metrics = data ?? [];

  // Tabs: distinct period types in their display order.
  const periodTypes = [...new Set(metrics.map((metric) => metric.periodType))]
    .toSorted((a, b) => periodTypeRank(a) - periodTypeRank(b));

  // Derive the active tab instead of setting state during render: honour the
  // user's choice while it still exists in the data, otherwise fall back to
  // the first available period type (or null when there is no data).
  const selectionIsValid = selectedPeriod !== null
    && periodTypes.some((periodType) => periodType === selectedPeriod);
  const activePeriod = selectionIsValid ? selectedPeriod : (periodTypes[0] ?? null);

  // Only metrics of the active period type can appear in the table, so index
  // just those; this also stops equal period names from different period
  // types overwriting each other.
  const activeMetrics = metrics.filter((metric) => metric.periodType === activePeriod);

  // Columns: distinct period names of the active type, newest first.
  const periods = [...new Set(activeMetrics.map((metric) => metric.periodName))]
    .toSorted((a, b) => comparePeriodNamesDesc(a, b, activePeriod));

  // Rows: every label present in the data, regardless of period type.
  const labels = [...new Set(metrics.map((metric) => metric.label.name))]
    .toSorted();

  // Lookup of metric by "<label>-<period>"; a later duplicate replaces an earlier one.
  const metricByCell = new Map();
  activeMetrics.forEach((metric) => {
    metricByCell.set(`${metric.label.name}-${metric.periodName}`, metric);
  });

  return (
    <Box>
      <Box
        my={1}
      >
        <ButtonTabs
          width={200 * periodTypes.length}
          options={periodTypes}
          onClick={setSelectedPeriod}
          activeKey={activePeriod}
        />
        {metrics.length === 0 && (
          <Alert severity="error">
            No metrics were added for this task.
          </Alert>
        )}
        {metrics.length > 0 && (
          <Box
            display="block"
            maxWidth="100%"
            overflow="auto"
          >
            <Table>
              <TableHead>
                <TableRow>
                  <TableCell />
                  {periods.map((period) => (
                    <TableCell key={period}>{period}</TableCell>
                  ))}
                </TableRow>
              </TableHead>
              <TableBody>
                {labels.map((label) => (
                  <TableRow key={label}>
                    <StickyCell>{label}</StickyCell>
                    {periods.map((period) => {
                      const dataId = `${label}-${period}`;
                      const metric = metricByCell.get(dataId);
                      // previousValues may be absent; only its first entry is displayed.
                      const previousValue = metric?.previousValues?.[0]?.value;
                      if (hasValue(previousValue)) {
                        return (
                          <UpdatedCell key={dataId}>
                            {`${previousValue.toLocaleString()} > `}
                            {metric.value?.toLocaleString()}
                          </UpdatedCell>
                        );
                      }
                      if (hasValue(metric?.value)) {
                        return (
                          <NewCell key={dataId}>
                            {metric.value.toLocaleString()}
                          </NewCell>
                        );
                      }
                      return <TableCell key={dataId} />;
                    })}
                  </TableRow>
                ))}
              </TableBody>
            </Table>
          </Box>
        )}
      </Box>
    </Box>
  );
}
