import { useState, useEffect } from "react";
import styled from "styled-components";

import { getDatasetPreviewByDatasetId, getDatasetInfoByDatasetId } from "api/services/dataService";
import { isEmpty, last, round } from "lodash";
import ProgressBar from "components/ui/ProgressBar";
import MessageModalTrigger from "components/views/MessageModalTrigger";
import useTrainingProgressForJobId from "api/services/jobService/useTrainingProgressForJobId";
import { patchJobById } from "api/services/jobService";
import { postContinuePipelineTraining } from "api/services/projectService";
import { format } from "date-fns";
import { Area, AreaChart, CartesianGrid, Line, LineChart, ResponsiveContainer, XAxis, YAxis } from "recharts";
import { AccessTime, Train } from "@material-ui/icons";
import { ModelIcon, PlotIcon, TrainingIcon } from "components/ui/Icons";
import TrModal from "components/TrModal";
import { Gap } from "components/Layout";

/**
 * Human-readable labels for known dataset column names. Columns not listed here are displayed
 * under their raw name. Frozen because it is a shared lookup table that must not be mutated.
 */
const columnNameToDisplayName = Object.freeze({
  context: "Context",
  prompt: "Prompt",
  generatedText: "Target Output",
  "generatedText (predicted)": "GPT Output",
});

/**
 * Maps a raw dataset column name to its display label.
 *
 * @param {string} value - raw column name
 * @returns {*} the label from `columnNameToDisplayName`, or `value` unchanged when no label exists
 */
const disp = value => {
  // Own-property lookup so names like "constructor" or "toString" are not resolved through
  // Object.prototype (which would return a function instead of a label).
  const hasLabel = Object.prototype.hasOwnProperty.call(columnNameToDisplayName, value);
  return (hasLabel && columnNameToDisplayName[value]) || value;
};

/**
 * Outer wrapper of the preview: theme `closer0_5` background with horizontal padding.
 * A truthy `isDisabled` fades it to 10% opacity and turns off pointer events.
 */
const Container = styled.div`
  background-color: ${({ theme }) => theme.color.closer0_5};
  padding: 0 14px;
  ${({ isDisabled }) => isDisabled && "opacity: 0.1; pointer-events: none;"}
`;

/** Header cell for the preview tables: bold, left-aligned, underlined in the theme's `closest` colour. */
const Th = styled.th`
  padding: 20px 14px;
  font-weight: bold;
  text-align: left;
  border-bottom: 1px solid ${({ theme }) => theme.color.closest};
`;

/**
 * Preview table requesting the fixed table-layout algorithm.
 * NOTE(review): no width is set here; per CSS 2.1 a table whose width is `auto` may still use
 * automatic layout, in which case the Td ellipsis truncation will not kick in — confirm intent.
 */
const Table = styled("table")`
  table-layout: fixed;
`;

/** TrModal row that shows a pointer cursor and dims to half opacity while hovered. */
const StyledTrModal = styled(TrModal)`
  cursor: pointer;

  :hover {
    opacity: 0.5;
  }
`;

/**
 * Body cell for the preview tables, truncated to one line with an ellipsis.
 * `isNull` switches the text to the muted `closer1` colour. `isHighlighted` adds a translucent
 * `success` background; the appended `22` acts as an alpha channel only if theme.color.success
 * is a #rrggbb hex value (NOTE(review): confirm).
 */
const Td = styled.td`
  width: 25%;
  padding: 14px;
  color: ${({ isNull, theme }) => (isNull ? theme.color.closer1 : theme.color.closest)};

  white-space: nowrap;
  overflow: hidden;
  text-overflow: ellipsis;

  ${({ isHighlighted, theme }) => isHighlighted && `background-color: ${theme.color.success}22;`}
`;

/** Horizontally centred three-column grid that lays out the summary cards. */
const CardsAndBar = styled.div`
  display: grid;
  grid-template-columns: repeat(3, auto);
  justify-content: center;
  /* 8px between rows, 40px between columns */
  gap: 8px 40px;
  padding-top: 24px;
  padding-bottom: 24px;
`;

/**
 * Thin, square-edged ProgressBar that is transparent by default and fades in over 0.4s
 * when `isVisible` is truthy.
 * NOTE(review): in this file it sits inside BarAndPercentage, a flex container, where
 * `grid-area` has no effect.
 */
const StyledProgressBar = styled(ProgressBar)`
  height: 2px;
  border-radius: 0;
  grid-area: bottom;

  transition: opacity 0.4s;
  opacity: 0;
  ${({ isVisible }) => isVisible && "opacity: 1;"}
`;

/**
 * Flex row holding the progress bar. It is transparent by default and fades in over 0.2s
 * when `isVisible` is truthy.
 * NOTE(review): in this file it is rendered directly in Container (not inside the CardsAndBar
 * grid), so `grid-column` has no effect there.
 */
const BarAndPercentage = styled.div`
  display: flex;
  align-items: center;
  /* Previously declared twice (4px, then 8px); 8px was the effective value. */
  gap: 8px;
  grid-column: span 3;

  opacity: 0;
  ${props => props.isVisible && "opacity: 1;"}
  transition: opacity 0.2s;
`;

/**
 * Formats a date value with date-fns `format`, returning "" instead of throwing.
 *
 * @param {string|number|Date|null|undefined} dateStr - any value accepted by the Date constructor
 * @param {string} formatStr - date-fns format pattern, e.g. "d MMM yyyy"
 * @returns {string} the formatted date, or "" when the input is missing or not a valid date
 */
const safeDateFormat = (dateStr, formatStr) => {
  // new Date(null) is the Unix epoch rather than an invalid date, so missing values must be
  // rejected explicitly; otherwise they would be formatted as the epoch instead of "".
  if (dateStr === null || dateStr === undefined || dateStr === "") {
    return "";
  }
  try {
    return format(new Date(dateStr), formatStr) || "";
  } catch (e) {
    // date-fns `format` throws a RangeError for invalid dates; best-effort: show nothing.
    return "";
  }
};

/**
 * Square 40px badge pinned over the top-left corner of a card, centring an icon.
 * Defaults to a grey background; callers in this file override it through `style`.
 */
const NumberOverlaySquare = styled.div`
  position: absolute;
  top: -12px;
  left: 4px;
  width: 40px;
  height: 40px;

  display: flex;
  align-items: center;
  justify-content: center;

  background-color: grey;
  box-shadow: ${({ theme }) => theme.shadow};

  svg {
    height: 28px;
    fill: ${({ theme }) => theme.color.furthest};
  }
`;

/**
 * 240px summary card that stacks a title and value, right-aligned, on the theme's `furthest` background.
 * NOTE(review): `border-image` has no slice value, so it defaults to 100%; per the CSS spec that
 * paints the gradient only in the corner regions — confirm that is the intended look.
 */
const NumberCard = styled.div`
  position: relative;
  width: 240px;
  padding: 12px;

  display: flex;
  flex-direction: column;
  align-items: end;
  gap: 12px;

  background-color: ${({ theme }) => theme.color.furthest};
  box-shadow: ${({ theme }) => theme.shadow};
  /* The border shorthand resets border-image, so it has to come first. */
  border: 4px solid;
  border-image: linear-gradient(
    ${({ theme }) => theme.color.highlightGrey},
    ${({ theme }) => theme.color.highlightGrey}
  );
`;

/** Small card heading in the muted `closer2` theme colour. */
const CardTitle = styled.div`
  color: ${({ theme }) => theme.color.closer2};
  font-size: 14px;
`;

/** Emphasised value text shown beneath a CardTitle. */
const CardValue = styled.div`
  font-weight: 600;
  font-size: 18px;
`;

/**
 * Fills its parent and recolours Recharts grid lines and axis tick labels with theme colours.
 * These rules only match when the chart renders a CartesianGrid or axes.
 */
const PlotContainer = styled.div`
  width: 100%;
  height: 100%;
  .recharts-cartesian-grid {
    line {
      stroke: ${props => props.theme.color.closer1};
    }
  }
  .recharts-cartesian-axis-tick {
    text {
      fill: ${props => props.theme.color.closest};
    }
  }
`;

/**
 * 240px card frame for the mini loss plot, matching NumberCard's border and shadow.
 * NOTE(review): as with NumberCard, `border-image` has no slice value (default 100%), which
 * paints only the corner regions — confirm intent.
 */
const PlotCard = styled.div`
  position: relative;
  width: 240px;

  box-shadow: ${({ theme }) => theme.shadow};
  background-color: ${({ theme }) => theme.color.furthest};
  /* The border shorthand resets border-image, so it has to come first. */
  border: 4px solid;
  border-image: linear-gradient(
    ${({ theme }) => theme.color.highlightGrey},
    ${({ theme }) => theme.color.highlightGrey}
  );
`;

/** Fixed-height (200px) scrollable box that shows long context text with its line breaks preserved. */
const ContextCard = styled.div`
  height: 200px;
  max-width: 80%;
  padding: 10px;
  overflow-y: auto;

  background-color: ${({ theme }) => theme.color.closer1};
  white-space: pre-wrap;
  line-height: 1.2;
`;

/** Full-width bold section heading with a light grey rule above it, used in the row-detail modal. */
const PreviewSectionTitle = styled.div`
  width: 100%;
  padding: 8px 0;
  border-top: 1px solid lightgrey;
  font-size: 16px;
  font-weight: 600;
`;

const MiniPlot = ({ iterationPoints }) => {
  const dataKeys = Object.keys(iterationPoints?.[0] || {});

  return (
    <PlotCard>
      <NumberOverlaySquare style={{ backgroundColor: "#0191ff" }}>
        <PlotIcon style={{ transform: "scale(0.7)" }} />
      </NumberOverlaySquare>
      <CardTitle style={{ backgroundColor: "white", zIndex: 1, position: "absolute", top: "12px", right: "12px" }}>
        Loss log
      </CardTitle>
      <PlotContainer>
        <ResponsiveContainer width="100%" height="100%">
          <AreaChart
            data={iterationPoints}
            margin={{
              top: 0,
              right: 0,
              left: 55,
              bottom: 0,
            }}
          >
            <defs>
              <linearGradient id="areaGradient" x1="0" y1="0" x2="0" y2="1">
                <stop offset="5%" stopColor="#0191ff" stopOpacity={1} />
                <stop offset="100%" stopColor="#0191ff" stopOpacity={0} />
              </linearGradient>
            </defs>
            {dataKeys.includes("loss") && (
              <Area
                isAnimationActive={false}
                type="monotone"
                dataKey="loss"
                name="Trainng loss"
                stroke="#0191ff"
                fill="url(#areaGradient)"
                activeDot={{ r: 8 }}
                strokeWidth={1}
                dot={false}
              />
            )}
          </AreaChart>
        </ResponsiveContainer>
      </PlotContainer>
    </PlotCard>
  );
};

const DatasetPreviewAnimated = ({
  pipelineOutput = {},
  datasetId,
  jobId,
  pipelineId,
  doRestartPipelineOutputPolling = () => {},
}) => {
  const [isLoading, setIsLoading] = useState(true);
  const [iterationPoints, trainingJob] = useTrainingProgressForJobId(jobId, pipelineId);

  const [modalInfo, setModalInfo] = useState({
    context: "",
    prompt: "",
    generatedText: "",
  });

  const doPauseJob = async () => {
    setIsLoading(true);
    await patchJobById(jobId, {}, { cancelRequested: true });
    doRestartPipelineOutputPolling();
  };

  const doResumeJob = async () => {
    setIsLoading(true);
    const { error } = await postContinuePipelineTraining(pipelineId);
    if (error) {
      alert("Error resuming training ", JSON.stringify(error));
    }
    setIsLoading(false);
    doRestartPipelineOutputPolling();
  };

  const outputCololumnNames = iterationPoints?.[0]?.outputColumns || [];
  const predictedColumnNames = iterationPoints?.[0]?.predictedColumns || [];
  const columnNames = iterationPoints?.[0]?.displayColumns || [];

  const lastLogItem = last(iterationPoints);

  return (
    <Container>
      <CardsAndBar>
        {trainingJob?.finishedAt ? (
          <NumberCard>
            <NumberOverlaySquare style={{ backgroundColor: "#0191ff" }}>
              <AccessTime />
            </NumberOverlaySquare>
            <CardTitle>Training finished on</CardTitle>
            <CardValue>{safeDateFormat(trainingJob?.finishedAt, "d MMM yyyy") || "-"}</CardValue>
          </NumberCard>
        ) : (
          <NumberCard>
            <NumberOverlaySquare style={{ backgroundColor: "#0191ff" }}>
              <AccessTime />
            </NumberOverlaySquare>
            <CardTitle>Training progress</CardTitle>
            <CardValue>{trainingJob?.progress}%</CardValue>
          </NumberCard>
        )}

        {trainingJob?.status === "FAILED" ? (
          <NumberCard>
            <NumberOverlaySquare style={{ backgroundColor: "#9650ff" }} />
            <CardTitle>Training Status</CardTitle>
            <CardValue>{trainingJob?.status}</CardValue>
          </NumberCard>
        ) : (
          <NumberCard>
            <NumberOverlaySquare style={{ backgroundColor: "#9650ff" }}>
              <ModelIcon style={{ transform: "scale(0.75)" }} />
            </NumberOverlaySquare>
            <CardTitle>Last {lastLogItem?.accuracy ? "training accuracy" : "validation loss"}</CardTitle>
            <CardValue>{round(lastLogItem?.accuracy || lastLogItem?.validationLoss, 2) || "-"}</CardValue>
          </NumberCard>
        )}

        <MiniPlot iterationPoints={iterationPoints} />
      </CardsAndBar>

      <BarAndPercentage isVisible>
        <StyledProgressBar isVisible currentValue={trainingJob?.progress || 0} maxValue={100} />
      </BarAndPercentage>

      <Table>
        <thead>
          <tr>
            {!isEmpty(columnNames) &&
              columnNames.map(columnName => {
                if (outputCololumnNames?.includes(columnName)) {
                  return (
                    <Th style={{ color: "#0191ff" }} key={columnName}>
                      {disp(columnName)}
                    </Th>
                  );
                }

                if (predictedColumnNames?.includes(columnName)) {
                  return (
                    <Th style={{ color: "#9650ff" }} key={columnName}>
                      {disp(columnName)}
                    </Th>
                  );
                }

                return <Th key={columnName}>{disp(columnName)}</Th>;
              })}
          </tr>
        </thead>
        <tbody>
          {last(iterationPoints)?.validationDatasetPreview?.map((dataPoint, rowIndex) => (
            <StyledTrModal
              key={rowIndex}
              onClick={() => {
                setModalInfo({
                  context: dataPoint?.[columnNames?.[0]],
                  prompt: dataPoint?.[columnNames?.[1]],
                });
              }}
              modalContent={
                <div>
                  <PreviewSectionTitle style={{ marginTop: "-10px", borderTop: "none" }}>Context</PreviewSectionTitle>
                  <ContextCard>{modalInfo?.context}</ContextCard>
                  <Gap height="40px" />
                  <PreviewSectionTitle>Prompt</PreviewSectionTitle>
                  <div>{modalInfo?.prompt}</div>
                  <Gap height="40px" />
                  <table style={{ borderTop: "1px solid lightgrey" }}>
                    <thead>
                      <tr>
                        {!isEmpty(columnNames) &&
                          columnNames.map(columnName => {
                            if (outputCololumnNames?.includes(columnName)) {
                              return (
                                <Th style={{ color: "#0191ff" }} key={columnName}>
                                  {disp(columnName)}
                                </Th>
                              );
                            }

                            if (predictedColumnNames?.includes(columnName)) {
                              return (
                                <Th style={{ color: "#9650ff" }} key={columnName}>
                                  {disp(columnName)}
                                </Th>
                              );
                            }

                            return null;
                          })}
                      </tr>
                    </thead>
                    <tbody>
                      <tr>
                        {[...outputCololumnNames, ...predictedColumnNames].map((columnName, colIndex) => {
                          const val = Array.isArray(dataPoint[columnName])
                            ? dataPoint[columnName].join(" ")
                            : dataPoint[columnName];

                          return (
                            <Td
                              style={{ whiteSpace: "normal", overflow: "visible" }}
                              isHighlighted={predictedColumnNames?.includes(columnName)}
                              isNull={val === null}
                              key={`${rowIndex}-${val}-${colIndex}`}
                            >
                              {typeof val === "number" ? round(val, 2) : val}
                            </Td>
                          );
                        })}
                      </tr>
                    </tbody>
                  </table>
                </div>
              }
            >
              {columnNames.map((columnName, colIndex) => {
                const val = Array.isArray(dataPoint[columnName])
                  ? dataPoint[columnName].join(" ")
                  : dataPoint[columnName];

                return (
                  <Td
                    isHighlighted={predictedColumnNames?.includes(columnName)}
                    isNull={val === null}
                    key={`${rowIndex}-${val}-${colIndex}`}
                  >
                    {typeof val === "number" ? round(val, 2) : val}
                  </Td>
                );
              })}
            </StyledTrModal>
          ))}
        </tbody>
      </Table>
    </Container>
  );
};

export default DatasetPreviewAnimated;
