import {useMutation, useQueryClient} from 'react-query'
import {useRecoilValue} from 'recoil'
import {reportApi} from '../../api-interface'
import {GetReportViewersResponse, Group, User} from '../../models'
import {licenseAtom} from '../../state'
import {
  getGroupsKey,
  getOtherUsersOnLicenseKey,
  getViewersKey,
} from '../queries'

/** Variables accepted by the mutation returned from `useAddViewersToReport`. */
type AddViewersPayload = {
  /** Id of the report; also part of the viewers query key that is updated. */
  reportId: string
  /** Ids to add; matched against user ids or group ids depending on `viewerType`. */
  viewerIds: string[]
  /**
   * Selects which cached list (`users` or `groups`) the optimistic update
   * touches. It is not forwarded to `reportApi.addViewers`.
   */
  viewerType: 'user' | 'group'
}

/**
 * Add viewers to a PBI report.
 *
 * Returns a react-query mutation that calls `reportApi.addViewers` with the
 * report id, the viewer ids and the current license id. While the request is
 * in flight the cached viewers of the report are updated optimistically:
 * the users (or groups) whose ids are in `payload.viewerIds` are looked up in
 * the cached license users / groups lists, merged with the current viewers and
 * sorted by first name (users) or name (groups).
 *
 * If the request fails, the cached viewers are restored to the snapshot taken
 * before the optimistic update.
 *
 * NOTE(review): the viewers query is not invalidated after the mutation
 * settles, so the cache keeps the optimistic value until something else
 * refetches it — confirm this is intended.
 */
export function useAddViewersToReport() {
  const licenseId = useRecoilValue(licenseAtom)
  const queryClient = useQueryClient()

  return useMutation(
    (payload: AddViewersPayload) =>
      reportApi.addViewers(payload.reportId, payload.viewerIds, licenseId),
    {
      async onMutate(
        payload: AddViewersPayload
      ): Promise<GetReportViewersResponse> {
        const viewersKey = [getViewersKey, payload.reportId]

        // Stop in-flight fetches so they cannot overwrite the optimistic value.
        await queryClient.cancelQueries(viewersKey)

        // The candidate lists may not be cached yet; fall back to empty lists
        // instead of throwing inside onMutate.
        const allUsers: User[] =
          queryClient.getQueryData<User[]>([
            getOtherUsersOnLicenseKey,
            licenseId,
          ]) ?? []

        const allGroups: Group[] =
          queryClient.getQueryData<Group[]>([getGroupsKey, licenseId]) ?? []

        // Snapshot used for rollback. It must never be mutated below.
        const previousViewers: GetReportViewersResponse =
          queryClient.getQueryData(viewersKey)

        queryClient.setQueryData(
          viewersKey,
          (old: GetReportViewersResponse) => {
            // Always build a NEW object: `old` is the same reference as
            // `previousViewers`, so mutating it would corrupt the rollback
            // snapshot and break react-query's change detection.
            if (payload.viewerType === 'user') {
              return {
                ...old,
                users: allUsers
                  .filter((u: User) => payload.viewerIds.includes(u.id))
                  .concat(previousViewers?.users || [])
                  .sort((a: User, b: User) =>
                    a.firstName.localeCompare(b.firstName)
                  ),
              }
            }

            if (payload.viewerType === 'group') {
              return {
                ...old,
                groups: allGroups
                  .filter((g: Group) => payload.viewerIds.includes(g.id))
                  .concat(previousViewers?.groups || [])
                  .sort((a: Group, b: Group) => a.name.localeCompare(b.name)),
              }
            }

            return old
          }
        )

        // Becomes the `context` argument of onError.
        return previousViewers
      },
      onError(
        _err,
        payload: AddViewersPayload,
        context: GetReportViewersResponse
      ) {
        // Roll the cache back to the pre-mutation snapshot.
        queryClient.setQueryData([getViewersKey, payload.reportId], context)
      },
    }
  )
}
