import { inject, injectable } from 'inversify'
import { Centrifuge, UnauthorizedError } from 'centrifuge'
import { SERVICE_TYPES } from '@/core/container/types'
import type { Config } from '@/core/types'
import type LoggerService from '@/services/loggerService'
import type {
  EventToPayloadAbstractMap,
  HandlerOfAllEvents,
  ITypedWebSocketService,
  SocketEventsUnionFromMap,
  SocketTokenGettingCallback,
} from '@/services/transport/types'
import type { MonitoringServiceInterface } from '@/services/monitoring/types'
import TmWebSocketConnectionError from '@/core/error/socket/TmWebSocketConnectionError'
import TmLogicError from '@/core/error/tmLogicError'

/**
 * Typed WebSocket transport built on a single Centrifuge client.
 *
 * The client is created lazily — by `connect()`/`tryConnect()`, by registering a listener,
 * or by `publish()` — using tokens obtained from the callback set via
 * `setTokenGettingCallback`. Publications received on the subscribed channel go to the
 * handlers registered for that channel; when there are none, they go to the
 * default-channel handlers.
 */
@injectable()
export default class WebSocketChannelsService<EventMap extends EventToPayloadAbstractMap>
  implements ITypedWebSocketService<EventMap>
{
  /** Current Centrifuge client; `null` before the first connect and after `disconnect()`. */
  private connection: Centrifuge | null = null

  /** In-flight client creation, shared by concurrent callers so only one client is ever built. */
  private tryConnectPromise: Promise<Centrifuge> | null = null

  private readonly connectionUrl: string

  private defaultChannelListeners: HandlerOfAllEvents<EventMap>[] = []

  private channelListeners: { [channel: string]: HandlerOfAllEvents<EventMap>[] } = {}

  private statusChangeListeners: ((isConnected: boolean) => void)[] = []

  /** Token provider; throws until `setTokenGettingCallback` has been called. */
  private getTokens: SocketTokenGettingCallback = () => {
    throw new TmLogicError('Use setTokenGettingCallback first')
  }

  constructor(
    @inject(SERVICE_TYPES.Config) protected readonly config: Config,
    @inject(SERVICE_TYPES.LoggerService) protected readonly loggerService: LoggerService,
    @inject(SERVICE_TYPES.MonitoringService) protected readonly monitoringService: MonitoringServiceInterface,
  ) {
    this.connectionUrl = config.socketUrl
  }

  /** Sets the callback used to obtain connection and subscription tokens. */
  public setTokenGettingCallback(callback: SocketTokenGettingCallback) {
    this.getTokens = callback
  }

  /**
   * Ensures a Centrifuge client exists and is connected or connecting.
   *
   * If a client already exists, it is asked to connect unless it is already connected.
   * Otherwise a new client is created. Concurrent calls share a single creation, so
   * overlapping callers never produce duplicate clients.
   *
   * @returns The (possibly still connecting) Centrifuge client.
   * @throws Whatever the token callback throws, or `TmWebSocketConnectionError` when no
   *   subscription token is provided.
   */
  public async tryConnect(): Promise<Centrifuge> {
    const existing = this.connection
    if (existing) {
      if (existing.state !== 'connected') {
        existing.connect()
      }
      return existing
    }
    if (!this.tryConnectPromise) {
      this.tryConnectPromise = this.createConnection().finally(() => {
        this.tryConnectPromise = null
      })
    }
    return this.tryConnectPromise
  }

  /**
   * Connects without throwing. Failures are logged locally and reported to monitoring
   * as info.
   */
  public async connect() {
    try {
      await this.tryConnect()
    } catch (e) {
      let errorMessage: string
      if (e instanceof Error) {
        this.logError('Cannot connect to websocket')
        this.logError(e.message)
        this.logError(`${e.stack}`)
        errorMessage = e.toString()
      } else {
        errorMessage = `${e}`
      }
      this.monitoringService.logInfo(`Cannot connect to websocket. ${errorMessage}`)
    }
  }

  /**
   * Tears down the current client: disconnects it if needed, detaches all of its listeners
   * and forgets it, so the next connect builds a fresh client with fresh tokens.
   * A no-op when there is no client. A creation already in flight is not cancelled.
   */
  public disconnect() {
    const connection = this.connection
    if (!connection) {
      return
    }
    if (connection.state !== 'disconnected') {
      connection.disconnect()
    }
    connection.removeAllListeners()
    this.connection = null
  }

  /** Registers `handler` for publications on `channel` (deduplicated) and starts connecting if needed. */
  public async addEventListener(channel: string, handler: HandlerOfAllEvents<EventMap>) {
    let channelCallbacks = this.channelListeners[channel]
    if (!channelCallbacks) {
      channelCallbacks = []
      this.channelListeners[channel] = channelCallbacks
    }
    if (!channelCallbacks.includes(handler)) {
      channelCallbacks.push(handler)
    }
    if (!this.isConnected()) {
      // connect() never rejects; it logs failures itself.
      void this.connect()
    }
  }

  /** Unregisters `handler` from `channel`; drops the channel entry once it has no handlers left. */
  public removeEventListener(channel: string, handler: HandlerOfAllEvents<EventMap>) {
    const channelCallbacks = this.channelListeners[channel]
    if (!channelCallbacks) {
      return
    }
    const index = channelCallbacks.indexOf(handler)
    if (index >= 0) {
      channelCallbacks.splice(index, 1)
    }
    if (channelCallbacks.length === 0) {
      // Removing the key lets the default listeners receive this channel again.
      delete this.channelListeners[channel]
    }
  }

  /** Publishes `data` to `channel`, creating the client first if there is none. */
  public async publish(channel: string, data: unknown): Promise<void> {
    const connection = await this.getConnectionSafe()
    await connection.publish(channel, data)
  }

  /** Removes every handler registered for `channel`. */
  public removeAllEventListener(channel: string): void {
    delete this.channelListeners[channel]
  }

  /** Registers a handler for publications on channels without specific handlers, and starts connecting if needed. */
  public addDefaultChannelListener(handler: HandlerOfAllEvents<EventMap>): void {
    if (!this.defaultChannelListeners.includes(handler)) {
      this.defaultChannelListeners.push(handler)
    }
    if (!this.isConnected()) {
      // connect() never rejects; it logs failures itself.
      void this.connect()
    }
  }

  public removeDefaultChannelListener(handler: HandlerOfAllEvents<EventMap>): void {
    const index = this.defaultChannelListeners.indexOf(handler)
    if (index >= 0) {
      this.defaultChannelListeners.splice(index, 1)
    }
  }

  public removeAllDefaultChannelListener() {
    this.defaultChannelListeners = []
  }

  /** Registers a connection-status handler and immediately calls it with the current status. */
  public addStatusChangeListener(handler: (isConnected: boolean) => void) {
    if (!this.statusChangeListeners.includes(handler)) {
      this.statusChangeListeners.push(handler)
      handler(this.isConnected())
    }
  }

  public removeStatusChangeListener(handler: (isConnected: boolean) => void) {
    const index = this.statusChangeListeners.indexOf(handler)
    if (index >= 0) {
      this.statusChangeListeners.splice(index, 1)
    }
  }

  /** `true` only when a client exists and its state is `'connected'`. */
  public isConnected() {
    return this.connection?.state === 'connected'
  }

  /** Returns the current client, creating one (via the shared, deduplicated path) when none exists. */
  protected async getConnectionSafe(): Promise<Centrifuge> {
    return this.connection ?? this.tryConnect()
  }

  /**
   * Returns the current client.
   * @throws TmWebSocketConnectionError when no client has been created.
   */
  protected getConnection() {
    if (!this.connection) {
      throw new TmWebSocketConnectionError('WebSocket connection not established')
    }
    return this.connection
  }

  protected log(message: string) {
    if (this.loggerService.shouldLogByChannel('bus', ['socket'])) {
      this.loggerService.log('bus', message, 'socket')
    }
  }

  protected logError(message: string) {
    if (this.loggerService.shouldLogByChannel('bus', ['socket'])) {
      this.loggerService.error('bus', message, 'socket')
    }
  }

  /**
   * Fetches tokens, builds the Centrifuge client and its single subscription (the first
   * subscription token's channel), wires up all handlers, and starts connecting.
   */
  private async createConnection(): Promise<Centrifuge> {
    const tokens = await this.getTokens({})
    const [primarySubscription] = tokens.subscriptionTokens
    if (!primarySubscription) {
      throw new TmWebSocketConnectionError('No subscription token provided')
    }

    const connection = new Centrifuge(this.connectionUrl, {
      token: tokens.connectionToken,
      getToken: async () => {
        const { connectionToken } = await this.getTokens({ isTokenRefreshing: true })
        if (!connectionToken) {
          throw new UnauthorizedError('No token provided')
        }
        return connectionToken
      },
    })
    this.connection = connection

    // Register client handlers before connecting so no early event can be missed.
    connection.once('connected', () => {
      this.log(`Web Socket connected with config: ${JSON.stringify(this.connectionUrl)}`)
    })
    connection.on('error', (ctx) => {
      // Error reporting is temporarily muted, waiting for research from the BE team around Centrifugo connection logs
      // this.monitoringService.logError(new Error(ctx.error.message))
      this.logError('Web Socket error.')
      this.logError(ctx.error?.message)
    })
    connection.on('disconnected', () => {
      this.notifyStatusChange(false)
    })
    connection.on('connected', () => {
      this.notifyStatusChange(true)
    })

    const sub = connection.newSubscription(primarySubscription.channel, {
      token: primarySubscription.subscriptionToken,
      getToken: async () => {
        const { subscriptionTokens } = await this.getTokens({ isTokenRefreshing: true })
        if (!subscriptionTokens.length || !subscriptionTokens[0].subscriptionToken) {
          throw new UnauthorizedError('No token provided')
        }
        return subscriptionTokens[0].subscriptionToken
      },
    })
    sub.on('publication', (ctx: SocketEventsUnionFromMap<EventMap>) => {
      this.dispatchPublication(ctx)
    })

    sub.subscribe()
    connection.connect()
    return connection
  }

  /**
   * Sends a publication to the channel's handlers, or to the default handlers when the
   * channel has none.
   */
  private dispatchPublication(ctx: SocketEventsUnionFromMap<EventMap>) {
    const channelCallbacks = this.channelListeners[ctx.channel]
    const targets =
      channelCallbacks && channelCallbacks.length > 0 ? channelCallbacks : this.defaultChannelListeners
    // Iterate over a copy so a handler can unsubscribe itself without skipping its neighbour.
    targets.slice().forEach((cb) => cb(ctx.data))
  }

  /** Calls every status listener (iterating over a copy) with the new connection status. */
  private notifyStatusChange(isConnected: boolean) {
    this.statusChangeListeners.slice().forEach((cb) => cb(isConnected))
  }
}
