use crate::ast::*;
use serde::{Deserialize, Serialize, Serializer};
#[derive(Debug, Clone)]
pub enum ScalarError {
    IncorrectSign,
    OutOfBounds,
}
pub type ScalarResult<T> = std::result::Result<T, ScalarError>;
impl ScalarValue {
    pub fn get_integer_ty(&self) -> IntegerTy {
        match self {
            ScalarValue::Isize(_) => IntegerTy::Isize,
            ScalarValue::I8(_) => IntegerTy::I8,
            ScalarValue::I16(_) => IntegerTy::I16,
            ScalarValue::I32(_) => IntegerTy::I32,
            ScalarValue::I64(_) => IntegerTy::I64,
            ScalarValue::I128(_) => IntegerTy::I128,
            ScalarValue::Usize(_) => IntegerTy::Usize,
            ScalarValue::U8(_) => IntegerTy::U8,
            ScalarValue::U16(_) => IntegerTy::U16,
            ScalarValue::U32(_) => IntegerTy::U32,
            ScalarValue::U64(_) => IntegerTy::U64,
            ScalarValue::U128(_) => IntegerTy::U128,
        }
    }
    pub fn is_int(&self) -> bool {
        matches!(
            self,
            ScalarValue::Isize(_)
                | ScalarValue::I8(_)
                | ScalarValue::I16(_)
                | ScalarValue::I32(_)
                | ScalarValue::I64(_)
                | ScalarValue::I128(_)
        )
    }
    pub fn is_uint(&self) -> bool {
        matches!(
            self,
            ScalarValue::Usize(_)
                | ScalarValue::U8(_)
                | ScalarValue::U16(_)
                | ScalarValue::U32(_)
                | ScalarValue::U64(_)
                | ScalarValue::U128(_)
        )
    }
    pub fn as_uint(&self) -> ScalarResult<u128> {
        match self {
            ScalarValue::Usize(v) => Ok(*v as u128),
            ScalarValue::U8(v) => Ok(*v as u128),
            ScalarValue::U16(v) => Ok(*v as u128),
            ScalarValue::U32(v) => Ok(*v as u128),
            ScalarValue::U64(v) => Ok(*v as u128),
            ScalarValue::U128(v) => Ok(*v),
            _ => Err(ScalarError::IncorrectSign),
        }
    }
    pub fn uint_is_in_bounds(ty: IntegerTy, v: u128) -> bool {
        match ty {
            IntegerTy::Usize => v <= (usize::MAX as u128),
            IntegerTy::U8 => v <= (u8::MAX as u128),
            IntegerTy::U16 => v <= (u16::MAX as u128),
            IntegerTy::U32 => v <= (u32::MAX as u128),
            IntegerTy::U64 => v <= (u64::MAX as u128),
            IntegerTy::U128 => true,
            _ => false,
        }
    }
    pub fn from_unchecked_uint(ty: IntegerTy, v: u128) -> ScalarValue {
        match ty {
            IntegerTy::Usize => ScalarValue::Usize(v as u64),
            IntegerTy::U8 => ScalarValue::U8(v as u8),
            IntegerTy::U16 => ScalarValue::U16(v as u16),
            IntegerTy::U32 => ScalarValue::U32(v as u32),
            IntegerTy::U64 => ScalarValue::U64(v as u64),
            IntegerTy::U128 => ScalarValue::U128(v),
            _ => panic!("Expected an unsigned integer kind"),
        }
    }
    pub fn from_uint(ty: IntegerTy, v: u128) -> ScalarResult<ScalarValue> {
        if !ScalarValue::uint_is_in_bounds(ty, v) {
            trace!("Not in bounds for {:?}: {}", ty, v);
            Err(ScalarError::OutOfBounds)
        } else {
            Ok(ScalarValue::from_unchecked_uint(ty, v))
        }
    }
    pub fn as_int(&self) -> ScalarResult<i128> {
        match self {
            ScalarValue::Isize(v) => Ok(*v as i128),
            ScalarValue::I8(v) => Ok(*v as i128),
            ScalarValue::I16(v) => Ok(*v as i128),
            ScalarValue::I32(v) => Ok(*v as i128),
            ScalarValue::I64(v) => Ok(*v as i128),
            ScalarValue::I128(v) => Ok(*v),
            _ => Err(ScalarError::IncorrectSign),
        }
    }
    pub fn int_is_in_bounds(ty: IntegerTy, v: i128) -> bool {
        match ty {
            IntegerTy::Isize => v >= (isize::MIN as i128) && v <= (isize::MAX as i128),
            IntegerTy::I8 => v >= (i8::MIN as i128) && v <= (i8::MAX as i128),
            IntegerTy::I16 => v >= (i16::MIN as i128) && v <= (i16::MAX as i128),
            IntegerTy::I32 => v >= (i32::MIN as i128) && v <= (i32::MAX as i128),
            IntegerTy::I64 => v >= (i64::MIN as i128) && v <= (i64::MAX as i128),
            IntegerTy::I128 => true,
            _ => false,
        }
    }
    pub fn from_unchecked_int(ty: IntegerTy, v: i128) -> ScalarValue {
        match ty {
            IntegerTy::Isize => ScalarValue::Isize(v as i64),
            IntegerTy::I8 => ScalarValue::I8(v as i8),
            IntegerTy::I16 => ScalarValue::I16(v as i16),
            IntegerTy::I32 => ScalarValue::I32(v as i32),
            IntegerTy::I64 => ScalarValue::I64(v as i64),
            IntegerTy::I128 => ScalarValue::I128(v),
            _ => panic!("Expected a signed integer kind"),
        }
    }
    pub fn from_le_bytes(ty: IntegerTy, b: [u8; 16]) -> ScalarValue {
        match ty {
            IntegerTy::Isize => {
                let b: [u8; 8] = b[0..8].try_into().unwrap();
                ScalarValue::Isize(i64::from_le_bytes(b))
            }
            IntegerTy::I8 => {
                let b: [u8; 1] = b[0..1].try_into().unwrap();
                ScalarValue::I8(i8::from_le_bytes(b))
            }
            IntegerTy::I16 => {
                let b: [u8; 2] = b[0..2].try_into().unwrap();
                ScalarValue::I16(i16::from_le_bytes(b))
            }
            IntegerTy::I32 => {
                let b: [u8; 4] = b[0..4].try_into().unwrap();
                ScalarValue::I32(i32::from_le_bytes(b))
            }
            IntegerTy::I64 => {
                let b: [u8; 8] = b[0..8].try_into().unwrap();
                ScalarValue::I64(i64::from_le_bytes(b))
            }
            IntegerTy::I128 => {
                let b: [u8; 16] = b[0..16].try_into().unwrap();
                ScalarValue::I128(i128::from_le_bytes(b))
            }
            IntegerTy::Usize => {
                let b: [u8; 8] = b[0..8].try_into().unwrap();
                ScalarValue::Usize(u64::from_le_bytes(b))
            }
            IntegerTy::U8 => {
                let b: [u8; 1] = b[0..1].try_into().unwrap();
                ScalarValue::U8(u8::from_le_bytes(b))
            }
            IntegerTy::U16 => {
                let b: [u8; 2] = b[0..2].try_into().unwrap();
                ScalarValue::U16(u16::from_le_bytes(b))
            }
            IntegerTy::U32 => {
                let b: [u8; 4] = b[0..4].try_into().unwrap();
                ScalarValue::U32(u32::from_le_bytes(b))
            }
            IntegerTy::U64 => {
                let b: [u8; 8] = b[0..8].try_into().unwrap();
                ScalarValue::U64(u64::from_le_bytes(b))
            }
            IntegerTy::U128 => {
                let b: [u8; 16] = b[0..16].try_into().unwrap();
                ScalarValue::U128(u128::from_le_bytes(b))
            }
        }
    }
    pub fn to_bits(&self) -> u128 {
        match *self {
            ScalarValue::Usize(v) => v as u128,
            ScalarValue::U8(v) => v as u128,
            ScalarValue::U16(v) => v as u128,
            ScalarValue::U32(v) => v as u128,
            ScalarValue::U64(v) => v as u128,
            ScalarValue::U128(v) => v,
            ScalarValue::Isize(v) => v as usize as u128,
            ScalarValue::I8(v) => v as u8 as u128,
            ScalarValue::I16(v) => v as u16 as u128,
            ScalarValue::I32(v) => v as u32 as u128,
            ScalarValue::I64(v) => v as u64 as u128,
            ScalarValue::I128(v) => v as u128,
        }
    }
    pub fn from_bits(ty: IntegerTy, bits: u128) -> Self {
        Self::from_le_bytes(ty, bits.to_le_bytes())
    }
    pub fn from_int(ty: IntegerTy, v: i128) -> ScalarResult<ScalarValue> {
        if !ScalarValue::int_is_in_bounds(ty, v) {
            Err(ScalarError::OutOfBounds)
        } else {
            Ok(ScalarValue::from_unchecked_int(ty, v))
        }
    }
    pub fn to_constant(self) -> ConstantExpr {
        ConstantExpr {
            value: RawConstantExpr::Literal(Literal::Scalar(self)),
            ty: TyKind::Literal(LiteralTy::Integer(self.get_integer_ty())).into_ty(),
        }
    }
}
impl Serialize for ScalarValue {
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let enum_name = "ScalarValue";
        let variant_name = self.variant_name();
        let (variant_index, _variant_arity) = self.variant_index_arity();
        let v = match self {
            ScalarValue::Isize(i) => i.to_string(),
            ScalarValue::I8(i) => i.to_string(),
            ScalarValue::I16(i) => i.to_string(),
            ScalarValue::I32(i) => i.to_string(),
            ScalarValue::I64(i) => i.to_string(),
            ScalarValue::I128(i) => i.to_string(),
            ScalarValue::Usize(i) => i.to_string(),
            ScalarValue::U8(i) => i.to_string(),
            ScalarValue::U16(i) => i.to_string(),
            ScalarValue::U32(i) => i.to_string(),
            ScalarValue::U64(i) => i.to_string(),
            ScalarValue::U128(i) => i.to_string(),
        };
        serializer.serialize_newtype_variant(enum_name, variant_index, variant_name, &v)
    }
}
impl<'de> Deserialize<'de> for ScalarValue {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct Visitor;
        impl<'de> serde::de::Visitor<'de> for Visitor {
            type Value = ScalarValue;
            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                write!(f, "ScalarValue")
            }
            fn visit_map<A: serde::de::MapAccess<'de>>(
                self,
                mut map: A,
            ) -> Result<Self::Value, A::Error> {
                use serde::de::Error;
                let (k, v): (String, String) = map.next_entry()?.expect("Malformed ScalarValue");
                Ok(match k.as_str() {
                    "Isize" => ScalarValue::Isize(v.parse().unwrap()),
                    "I8" => ScalarValue::I8(v.parse().unwrap()),
                    "I16" => ScalarValue::I16(v.parse().unwrap()),
                    "I32" => ScalarValue::I32(v.parse().unwrap()),
                    "I64" => ScalarValue::I64(v.parse().unwrap()),
                    "I128" => ScalarValue::I128(v.parse().unwrap()),
                    "Usize" => ScalarValue::Usize(v.parse().unwrap()),
                    "U8" => ScalarValue::U8(v.parse().unwrap()),
                    "U16" => ScalarValue::U16(v.parse().unwrap()),
                    "U32" => ScalarValue::U32(v.parse().unwrap()),
                    "U64" => ScalarValue::U64(v.parse().unwrap()),
                    "U128" => ScalarValue::U128(v.parse().unwrap()),
                    _ => {
                        return Err(A::Error::custom(format!(
                            "{k} is not a valid type for a ScalarValue"
                        )))
                    }
                })
            }
        }
        deserializer.deserialize_map(Visitor)
    }
}