1use core::fmt;
2use std::collections::{BTreeMap, HashMap};
3use std::fmt::Display;
4use std::marker::PhantomData;
5
6use num_traits::{Float, FromPrimitive, ToPrimitive};
7use serde::Serialize;
8use serde::de::Error;
9
10use super::{DeserializeContent, DeserializeHelper, Expected, Tracker, TrackerDeserializer, TrackerFor};
11
/// Tracker used for [`FloatWithNonFinite`] fields; its `Target` is the bare float `T`.
///
/// Zero-sized: the `PhantomData<T>` only records which type is being tracked.
pub struct FloatWithNonFinTracker<T>(PhantomData<T>);
13
14impl<T> fmt::Debug for FloatWithNonFinTracker<T> {
15 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
16 write!(f, "FloatWithNonFinTracker<{}>", std::any::type_name::<T>())
17 }
18}
19
20impl<T> Default for FloatWithNonFinTracker<T> {
21 fn default() -> Self {
22 Self(PhantomData)
23 }
24}
25
26impl<T: Expected> Tracker for FloatWithNonFinTracker<T> {
27 type Target = T;
28
29 #[inline(always)]
30 fn allow_duplicates(&self) -> bool {
31 false
32 }
33}
34
/// Wrapper around a float `T` whose serialized form writes non-finite values as the
/// strings `"NaN"`, `"Infinity"` and `"-Infinity"`; deserialization accepts those
/// same strings in addition to numbers.
///
/// `#[repr(transparent)]` is load-bearing: `FloatWithNonFinSerHelper::cast`
/// reinterprets a `&f32`/`&f64` as a reference to this wrapper.
#[repr(transparent)]
pub struct FloatWithNonFinite<T>(T);
37
38impl<T: Default> Default for FloatWithNonFinite<T> {
39 fn default() -> Self {
40 Self(Default::default())
41 }
42}
43
44impl<T: Expected> TrackerFor for FloatWithNonFinite<T> {
45 type Tracker = FloatWithNonFinTracker<T>;
46}
47
48impl<T> Expected for FloatWithNonFinite<T> {
49 fn expecting(formatter: &mut fmt::Formatter) -> fmt::Result {
50 write!(formatter, stringify!(T))
51 }
52}
53
/// Maps a field type to the type it should be deserialized through so that
/// non-finite floats are accepted.
///
/// Bare `f32`/`f64` map to [`FloatWithNonFinite`]; `Option`, `Vec`, and the value
/// type of `BTreeMap`/`HashMap` are mapped recursively.
pub trait FloatWithNonFinDesHelper
where
    Self: Sized,
{
    /// The substitute type used for deserialization.
    type Target;
}
59
60impl FloatWithNonFinDesHelper for f32 {
61 type Target = FloatWithNonFinite<f32>;
62}
63
64impl FloatWithNonFinDesHelper for f64 {
65 type Target = FloatWithNonFinite<f64>;
66}
67
68impl<T: FloatWithNonFinDesHelper> FloatWithNonFinDesHelper for Option<T> {
69 type Target = Option<T::Target>;
70}
71
72impl<T: FloatWithNonFinDesHelper> FloatWithNonFinDesHelper for Vec<T> {
73 type Target = Vec<T::Target>;
74}
75
76impl<K, V: FloatWithNonFinDesHelper> FloatWithNonFinDesHelper for BTreeMap<K, V> {
77 type Target = BTreeMap<K, V::Target>;
78}
79
80impl<K, V: FloatWithNonFinDesHelper, S> FloatWithNonFinDesHelper for HashMap<K, V, S> {
81 type Target = HashMap<K, V::Target, S>;
82}
83
84impl<'de, T> serde::de::DeserializeSeed<'de> for DeserializeHelper<'_, FloatWithNonFinTracker<T>>
85where
86 T: serde::Deserialize<'de> + Float + ToPrimitive + FromPrimitive,
87 FloatWithNonFinTracker<T>: Tracker<Target = T>,
88{
89 type Value = ();
90
91 fn deserialize<D>(self, de: D) -> Result<Self::Value, D::Error>
92 where
93 D: serde::Deserializer<'de>,
94 {
95 struct Visitor<T>(PhantomData<T>);
96
97 impl<T> Default for Visitor<T> {
98 fn default() -> Self {
99 Self(PhantomData)
100 }
101 }
102
103 macro_rules! visit_convert_to_float {
104 ($visitor_func:ident, $conv_func:ident, $ty:ident) => {
105 fn $visitor_func<E>(self, v: $ty) -> Result<Self::Value, E>
106 where
107 E: Error,
108 {
109 match T::$conv_func(v) {
110 Some(v) => Ok(v),
111 None => Err(E::custom(format!("unable to extract float-type from {}", v))),
112 }
113 }
114 };
115 }
116
117 impl<'de, T> serde::de::Visitor<'de> for Visitor<T>
118 where
119 T: serde::Deserialize<'de> + Float + ToPrimitive + FromPrimitive,
120 {
121 type Value = T;
122
123 visit_convert_to_float!(visit_f32, from_f32, f32);
124
125 visit_convert_to_float!(visit_f64, from_f64, f64);
126
127 visit_convert_to_float!(visit_u8, from_u8, u8);
128
129 visit_convert_to_float!(visit_u16, from_u16, u16);
130
131 visit_convert_to_float!(visit_u32, from_u32, u32);
132
133 visit_convert_to_float!(visit_u64, from_u64, u64);
134
135 visit_convert_to_float!(visit_i8, from_i8, i8);
136
137 visit_convert_to_float!(visit_i16, from_i16, i16);
138
139 visit_convert_to_float!(visit_i32, from_i32, i32);
140
141 visit_convert_to_float!(visit_i64, from_i64, i64);
142
143 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
144 write!(formatter, stringify!(T))
145 }
146
147 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
148 where
149 E: Error,
150 {
151 match v {
152 "Infinity" => Ok(T::infinity()),
153 "-Infinity" => Ok(T::neg_infinity()),
154 "NaN" => Ok(T::nan()),
155 _ => Err(E::custom(format!("unrecognized floating string: {}", v))),
156 }
157 }
158
159 fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
160 where
161 E: Error,
162 {
163 self.visit_str(v)
164 }
165
166 fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
167 where
168 E: Error,
169 {
170 self.visit_str(&v)
171 }
172 }
173
174 *self.value = de.deserialize_any(Visitor::default())?;
175 Ok(())
176 }
177}
178
179impl<'de, T> TrackerDeserializer<'de> for FloatWithNonFinTracker<T>
180where
181 T: serde::Deserialize<'de> + Float + FromPrimitive,
182 FloatWithNonFinTracker<T>: Tracker<Target = T>,
183{
184 fn deserialize<D>(&mut self, value: &mut Self::Target, deserializer: D) -> Result<(), D::Error>
185 where
186 D: DeserializeContent<'de>,
187 {
188 deserializer.deserialize_seed(DeserializeHelper { value, tracker: self })
189 }
190}
191
192impl<T: Float + FromPrimitive + Display> serde::Serialize for FloatWithNonFinite<T> {
195 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
196 where
197 S: serde::Serializer,
198 {
199 match (self.0.is_nan(), self.0.is_infinite(), self.0.is_sign_negative()) {
200 (true, _, _) => serializer.serialize_str("NaN"),
201 (false, true, true) => serializer.serialize_str("-Infinity"),
202 (false, true, false) => serializer.serialize_str("Infinity"),
203 _ => {
204 let converted = self
205 .0
206 .to_f64()
207 .ok_or_else(|| serde::ser::Error::custom(format!("Failed to convert {} to f64", self.0)))?;
208 serializer.serialize_f64(converted)
209 }
210 }
211 }
212}
213
/// Maps a serializable value to a layout-identical "helper" type in which every
/// float is wrapped in [`FloatWithNonFinite`]. Serializing through the helper then
/// emits non-finite floats as strings.
///
/// # Safety
///
/// Implementors must guarantee that `Self` and `Self::Helper` have identical memory
/// layout (size, alignment and bit-level representation). [`FloatWithNonFinSerHelper::cast`]
/// reinterprets a `&Self` as a `&Self::Helper` without any conversion.
unsafe trait FloatWithNonFinSerHelper: Sized {
    /// Layout-identical view of `Self` used for serialization.
    type Helper: Sized;

    /// Reinterprets `value` as a reference to its helper type, with the same lifetime.
    fn cast(value: &Self) -> &Self::Helper {
        // Cheap guard against an implementation that breaks the trait contract.
        // Both sides are compile-time constants, so this costs nothing in release builds.
        debug_assert_eq!(std::mem::size_of::<Self>(), std::mem::size_of::<Self::Helper>());
        debug_assert_eq!(std::mem::align_of::<Self>(), std::mem::align_of::<Self::Helper>());
        // SAFETY: the `unsafe trait` contract requires `Self` and `Self::Helper` to
        // share the same layout, so the pointee is a valid `Self::Helper` for as long
        // as `value` is borrowed.
        unsafe { &*(value as *const Self as *const Self::Helper) }
    }
}
225
226unsafe impl FloatWithNonFinSerHelper for f32 {
228 type Helper = FloatWithNonFinite<f32>;
229}
230
231unsafe impl FloatWithNonFinSerHelper for f64 {
233 type Helper = FloatWithNonFinite<f64>;
234}
235
236unsafe impl<T: Float + FromPrimitive> FloatWithNonFinSerHelper for FloatWithNonFinite<T> {
238 type Helper = FloatWithNonFinite<T>;
239}
240
241unsafe impl<T: FloatWithNonFinSerHelper> FloatWithNonFinSerHelper for Option<T> {
243 type Helper = Option<T::Helper>;
244}
245
246unsafe impl<T: FloatWithNonFinSerHelper> FloatWithNonFinSerHelper for Vec<T> {
248 type Helper = Vec<T::Helper>;
249}
250
251unsafe impl<K, V: FloatWithNonFinSerHelper> FloatWithNonFinSerHelper for BTreeMap<K, V> {
253 type Helper = BTreeMap<K, V::Helper>;
254}
255
256unsafe impl<K, V: FloatWithNonFinSerHelper, S> FloatWithNonFinSerHelper for HashMap<K, V, S> {
258 type Helper = HashMap<K, V::Helper, S>;
259}
260
261#[allow(private_bounds)]
262pub fn serialize_floats_with_non_finite<V, S>(value: &V, serializer: S) -> Result<S::Ok, S::Error>
263where
264 V: FloatWithNonFinSerHelper,
265 V::Helper: serde::Serialize,
266 S: serde::Serializer,
267{
268 V::cast(value).serialize(serializer)
269}