diff --git a/autometrics-macros/Cargo.toml b/autometrics-macros/Cargo.toml index 881163b..97d3513 100644 --- a/autometrics-macros/Cargo.toml +++ b/autometrics-macros/Cargo.toml @@ -14,6 +14,7 @@ categories = { workspace = true } proc-macro = true [dependencies] +Inflector = "0.11.4" percent-encoding = "2.2" proc-macro2 = "1" quote = "1" diff --git a/autometrics-macros/src/lib.rs b/autometrics-macros/src/lib.rs index c9e25e3..d8ae29f 100644 --- a/autometrics-macros/src/lib.rs +++ b/autometrics-macros/src/lib.rs @@ -1,9 +1,13 @@ use crate::parse::{AutometricsArgs, Item}; +use inflector::Inflector; use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC}; -use proc_macro2::TokenStream; +use proc_macro2::{Ident, TokenStream}; use quote::quote; use std::env; -use syn::{parse_macro_input, ImplItem, ItemFn, ItemImpl, Result}; +use syn::{ + parse_macro_input, Attribute, Data, DataEnum, DeriveInput, Error, ImplItem, ItemFn, ItemImpl, + Lit, Meta, NestedMeta, Result, +}; mod parse; @@ -129,6 +133,18 @@ pub fn autometrics( output.into() } +#[proc_macro_derive(AutometricsLabel, attributes(autometrics_label))] +pub fn derive_autometrics_label(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let result = derive_autometrics_label_impl(input); + let output = match result { + Ok(output) => output, + Err(err) => err.into_compile_error(), + }; + + output.into() +} + /// Add autometrics instrumentation to a single function fn instrument_function(args: &AutometricsArgs, item: ItemFn) -> Result { let sig = item.sig; @@ -176,10 +192,9 @@ fn instrument_function(args: &AutometricsArgs, item: ItemFn) -> Result, attach that as a label - let value_type = (&result).__autometrics_static_str(); + let value_type = (&result).get_label().map(|(_, v)| v); CounterLabels::new( #function_name, module_path!(), @@ -194,8 +209,8 @@ fn instrument_function(args: &AutometricsArgs, item: ItemFn) -> Result Result String { format!("sum by (function, module) {gauge_name}{{{label_key}=\"{label_value}\"}}") } + +fn derive_autometrics_label_impl(input: DeriveInput) -> Result { + let variants = match input.data { + Data::Enum(DataEnum { variants, .. }) => variants, + _ => { + return Err(Error::new_spanned( + input, + "#[derive(AutometricsLabel}] is only supported for enums", + )); + } + }; + + // Use the key provided or the snake case version of the enum name + let label_key = { + let attrs: Vec<_> = input + .attrs + .iter() + .filter(|attr| attr.path.is_ident("autometrics_label")) + .collect(); + + let key_from_attr = match attrs.len() { + 0 => None, + 1 => get_label_attr(attrs[0], "key")?, + _ => { + let mut error = syn::Error::new_spanned( + attrs[1], + "redundant `autometrics_label(key)` attribute", + ); + error.combine(syn::Error::new_spanned(attrs[0], "note: first one here")); + return Err(error); + } + }; + + let key_from_attr = key_from_attr.map(|value| value.to_string()); + + // Check casing of the user-provided value + if let Some(key) = &key_from_attr { + if key.as_str() != key.to_snake_case() { + return Err(Error::new_spanned(attrs[0], "key should be snake_cased")); + } + } + + let ident = input.ident.clone(); + key_from_attr.unwrap_or_else(|| ident.clone().to_string().to_snake_case()) + }; + + let value_match_arms = variants + .into_iter() + .map(|variant| { + let attrs: Vec<_> = variant + .attrs + .iter() + .filter(|attr| attr.path.is_ident("autometrics_label")) + .collect(); + + let value_from_attr = match attrs.len() { + 0 => None, + 1 => get_label_attr(attrs[0], "value")?, + _ => { + let mut error = Error::new_spanned( + attrs[1], + "redundant `autometrics_label(value)` attribute", + ); + error.combine(Error::new_spanned(attrs[0], "note: first one here")); + return Err(error); + } + }; + + let value_from_attr = value_from_attr.map(|value| value.to_string()); + + // Check casing of the user-provided value + if let Some(value) = &value_from_attr { + if value.as_str() != value.to_snake_case() { + return Err(Error::new_spanned(attrs[0], "value should be snake_cased")); + } + } + + let ident = variant.ident; + let value = + value_from_attr.unwrap_or_else(|| ident.clone().to_string().to_snake_case()); + Ok(quote! { + Self::#ident => #value, + }) + }) + .collect::>()?; + + let ident = input.ident; + Ok(quote! { + use ::autometrics::__private::{GetLabel, COUNTER_LABEL_KEYS, linkme}; + + #[linkme::distributed_slice(COUNTER_LABEL_KEYS)] + #[linkme(crate = ::autometrics::__private::linkme)] + pub static COUNTER_LABEL_KEY: &'static str = #label_key; + + #[automatically_derived] + impl GetLabel for #ident { + fn get_label(&self) -> Option<(&'static str, &'static str)> { + Some((#label_key, match self { + #value_match_arms + })) + } + } + }) +} + +fn get_label_attr(attr: &Attribute, attr_name: &str) -> Result> { + let meta = attr.parse_meta()?; + let meta_list = match meta { + Meta::List(list) => list, + _ => return Err(Error::new_spanned(meta, "expected a list-style attribute")), + }; + + let nested = match meta_list.nested.len() { + // `#[autometrics()]` without any arguments is a no-op + 0 => return Ok(None), + 1 => &meta_list.nested[0], + _ => { + return Err(Error::new_spanned( + meta_list.nested, + "currently only a single autometrics attribute is supported", + )); + } + }; + + let label_value = match nested { + NestedMeta::Meta(Meta::NameValue(nv)) => nv, + _ => { + return Err(Error::new_spanned( + nested, + format!("expected `{attr_name} = \"\"`"), + )) + } + }; + + if !label_value.path.is_ident(attr_name) { + return Err(Error::new_spanned( + &label_value.path, + format!("unsupported autometrics attribute, expected `{attr_name}`"), + )); + } + + match &label_value.lit { + Lit::Str(s) => syn::parse_str(&s.value()).map_err(|e| Error::new_spanned(s, e)), + lit => Err(Error::new_spanned(lit, "expected string literal")), + } +} diff --git a/autometrics/Cargo.toml b/autometrics/Cargo.toml index 22f721e..dd5149e 100644 --- a/autometrics/Cargo.toml +++ b/autometrics/Cargo.toml @@ -44,6 +44,9 @@ prometheus = { version = "0.13", default-features = false, optional = true } # Used for prometheus feature const_format = { version = "0.2", features = ["rust_1_51"], optional = true } +# Used to enumerate all generated counter types +linkme = "0.3.8" + [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] diff --git a/autometrics/src/constants.rs b/autometrics/src/constants.rs index d8b24d4..3581a81 100644 --- a/autometrics/src/constants.rs +++ b/autometrics/src/constants.rs @@ -13,8 +13,6 @@ pub const FUNCTION_KEY: &'static str = "function"; pub const MODULE_KEY: &'static str = "module"; pub const CALLER_KEY: &'static str = "caller"; pub const RESULT_KEY: &'static str = "result"; -pub const OK_KEY: &'static str = "ok"; -pub const ERROR_KEY: &'static str = "error"; pub const OBJECTIVE_NAME: &'static str = "objective.name"; pub const OBJECTIVE_PERCENTILE: &'static str = "objective.percentile"; pub const OBJECTIVE_LATENCY_THRESHOLD: &'static str = "objective.latency_threshold"; diff --git a/autometrics/src/labels.rs b/autometrics/src/labels.rs index 1f6959b..9876e04 100644 --- a/autometrics/src/labels.rs +++ b/autometrics/src/labels.rs @@ -1,9 +1,13 @@ use crate::{constants::*, objectives::*}; +use linkme::distributed_slice; use std::ops::Deref; pub(crate) type Label = (&'static str, &'static str); type ResultAndReturnTypeLabels = (&'static str, Option<&'static str>); +#[distributed_slice] +pub static COUNTER_LABEL_KEYS: [&'static str] = [..]; + /// These are the labels used for the `function.calls.count` metric. pub struct CounterLabels { pub(crate) function: &'static str, @@ -112,33 +116,6 @@ impl GaugeLabels { } } -// The following is a convoluted way to figure out if the return type resolves to a Result -// or not. We cannot simply parse the code using syn to figure out if it's a Result -// because syn doesn't do type resolution and thus would count any renamed version -// of Result as a different type. Instead, we define two traits with intentionally -// conflicting method names and use a trick based on the order in which Rust resolves -// method names to return a different value based on whether the return value is -// a Result or anything else. -// This approach is based on dtolnay's answer to this question: -// https://users.rust-lang.org/t/how-to-check-types-within-macro/33803/5 -// and this answer explains why it works: -// https://users.rust-lang.org/t/how-to-check-types-within-macro/33803/8 - -pub trait GetLabelsFromResult { - fn __autometrics_get_labels(&self) -> Option { - None - } -} - -impl GetLabelsFromResult for Result { - fn __autometrics_get_labels(&self) -> Option { - match self { - Ok(ok) => Some((OK_KEY, ok.__autometrics_static_str())), - Err(err) => Some((ERROR_KEY, err.__autometrics_static_str())), - } - } -} - pub enum LabelArray { Three([Label; 3]), Four([Label; 4]), @@ -157,12 +134,6 @@ impl Deref for LabelArray { } } -pub trait GetLabels { - fn __autometrics_get_labels(&self) -> Option { - None - } -} - /// Implement the given trait for &T and all primitive types. macro_rules! impl_trait_for_types { ($trait:ident) => { @@ -200,24 +171,19 @@ macro_rules! impl_trait_for_types { }; } -impl_trait_for_types!(GetLabels); - -pub trait GetStaticStrFromIntoStaticStr<'a> { - fn __autometrics_static_str(&'a self) -> Option<&'static str>; -} - -impl<'a, T: 'a> GetStaticStrFromIntoStaticStr<'a> for T -where - &'static str: From<&'a T>, -{ - fn __autometrics_static_str(&'a self) -> Option<&'static str> { - Some(self.into()) +pub trait GetLabel { + fn get_label(&self) -> Option<(&'static str, &'static str)> { + None } } -pub trait GetStaticStr { - fn __autometrics_static_str(&self) -> Option<&'static str> { - None +impl GetLabel for Result { + fn get_label(&self) -> Option<(&'static str, &'static str)> { + match self { + Ok(v) => (*v).get_label(), + Err(v) => (*v).get_label(), + } } } -impl_trait_for_types!(GetStaticStr); + +impl_trait_for_types!(GetLabel); diff --git a/autometrics/src/lib.rs b/autometrics/src/lib.rs index d93e205..01604e3 100644 --- a/autometrics/src/lib.rs +++ b/autometrics/src/lib.rs @@ -13,7 +13,12 @@ mod prometheus_exporter; mod task_local; mod tracker; -pub use autometrics_macros::autometrics; +pub extern crate linkme; + +pub use labels::GetLabel; + +pub extern crate autometrics_macros; +pub use autometrics_macros::{autometrics, AutometricsLabel}; // Optional exports #[cfg(feature = "prometheus-exporter")] @@ -47,4 +52,9 @@ pub mod __private { LocalKey { inner: CALLER_KEY } }; + + /// Re-export the linkme crate so that it can be used in the code generated by the autometrics macro + pub mod linkme { + pub use ::linkme::*; + } } diff --git a/autometrics/src/tracker/prometheus.rs b/autometrics/src/tracker/prometheus.rs index d543319..c9097b9 100644 --- a/autometrics/src/tracker/prometheus.rs +++ b/autometrics/src/tracker/prometheus.rs @@ -8,6 +8,7 @@ use prometheus::{ IntCounterVec, IntGaugeVec, }; use std::time::Instant; +use crate::__private::COUNTER_LABEL_KEYS; const COUNTER_NAME_PROMETHEUS: &str = str_replace!(COUNTER_NAME, ".", "_"); const HISTOGRAM_NAME_PROMETHEUS: &str = str_replace!(HISTOGRAM_NAME, ".", "_"); @@ -17,19 +18,20 @@ const OBJECTIVE_PERCENTILE_PROMETHEUS: &str = str_replace!(OBJECTIVE_PERCENTILE, const OBJECTIVE_LATENCY_PROMETHEUS: &str = str_replace!(OBJECTIVE_LATENCY_THRESHOLD, ".", "_"); static COUNTER: Lazy = Lazy::new(|| { + let mut keys = vec![ + FUNCTION_KEY, + MODULE_KEY, + CALLER_KEY, + RESULT_KEY, + OBJECTIVE_NAME_PROMETHEUS, + OBJECTIVE_PERCENTILE_PROMETHEUS, + ]; + keys.extend_from_slice(&COUNTER_LABEL_KEYS); + register_int_counter_vec!( COUNTER_NAME_PROMETHEUS, COUNTER_DESCRIPTION, - &[ - FUNCTION_KEY, - MODULE_KEY, - CALLER_KEY, - RESULT_KEY, - OK_KEY, - ERROR_KEY, - OBJECTIVE_NAME_PROMETHEUS, - OBJECTIVE_PERCENTILE_PROMETHEUS, - ] + &keys ) .expect(formatcp!( "Failed to register {COUNTER_NAME_PROMETHEUS} counter" @@ -82,36 +84,38 @@ impl TrackMetrics for PrometheusTracker { fn finish(self, counter_labels: &CounterLabels, histogram_labels: &HistogramLabels) { let duration = self.start.elapsed().as_secs_f64(); + // Put the label values in the same order as the keys in the counter definition + let mut label_values = vec![ + counter_labels.function, + counter_labels.module, + counter_labels.caller, + counter_labels.result.unwrap_or_default().0, + counter_labels + .objective + .as_ref() + .map(|obj| obj.0) + .unwrap_or(""), + counter_labels + .objective + .as_ref() + .map(|obj| obj.1.as_str()) + .unwrap_or(""), + ]; + + // Extend label_values with whatever user-defined label keys were defined + label_values.extend(COUNTER_LABEL_KEYS.iter() + .map(|label_key| { + // See if we can find a value for this user-defined label_key + let result = counter_labels.result.unwrap_or_default(); + if result.0 == *label_key { + result.1.unwrap_or("") + } else { + "" + } + })); + COUNTER - .with_label_values( - // Put the label values in the same order as the keys in the counter definition - &[ - counter_labels.function, - counter_labels.module, - counter_labels.caller, - counter_labels.result.unwrap_or_default().0, - if let Some((OK_KEY, Some(return_value_type))) = counter_labels.result { - return_value_type - } else { - "" - }, - if let Some((ERROR_KEY, Some(return_value_type))) = counter_labels.result { - return_value_type - } else { - "" - }, - counter_labels - .objective - .as_ref() - .map(|obj| obj.0) - .unwrap_or(""), - counter_labels - .objective - .as_ref() - .map(|obj| obj.1.as_str()) - .unwrap_or(""), - ], - ) + .with_label_values(&label_values) .inc(); HISTOGRAM diff --git a/autometrics/tests/autometrics_label.rs b/autometrics/tests/autometrics_label.rs new file mode 100644 index 0000000..7cbb933 --- /dev/null +++ b/autometrics/tests/autometrics_label.rs @@ -0,0 +1,70 @@ +use autometrics::GetLabel; +use autometrics_macros::AutometricsLabel; + +#[test] +fn custom_trait_implementation() { + struct CustomResult; + + impl GetLabel for CustomResult { + fn get_label(&self) -> Option<(&'static str, &'static str)> { + Some(("ok", "my-result")) + } + } + + assert_eq!(Some(("ok", "my-result")), CustomResult {}.get_label()); +} + +#[test] +fn manual_enum() { + enum MyFoo { + A, + B, + } + + impl GetLabel for MyFoo { + fn get_label(&self) -> Option<(&'static str, &'static str)> { + Some(("hello", match self { + MyFoo::A => "a", + MyFoo::B => "b", + })) + } + } + + assert_eq!(Some(("hello", "a")), MyFoo::A.get_label()); + assert_eq!(Some(("hello", "b")), MyFoo::B.get_label()); +} + +#[test] +fn derived_enum() { + #[derive(AutometricsLabel)] + #[autometrics_label(key = "my_foo")] + enum MyFoo { + #[autometrics_label(value = "hello")] + Alpha, + #[autometrics_label()] + BetaValue, + Charlie, + } + + assert_eq!(Some(("my_foo", "hello")), MyFoo::Alpha.get_label()); + assert_eq!(Some(("my_foo", "beta_value")), MyFoo::BetaValue.get_label()); + assert_eq!(Some(("my_foo", "charlie")), MyFoo::Charlie.get_label()); + + // A custom type that doesn't implement GetLabel + struct CustomType(u32); + + let result: Result = Ok(CustomType(123)); + assert_eq!(None, result.get_label()); + + let result: Result = Err(MyFoo::Alpha); + assert_eq!(Some(("my_foo", "hello")), result.get_label()); + + let result: Result = Ok(MyFoo::Alpha); + assert_eq!(Some(("my_foo", "hello")), result.get_label()); + + let result: Result = Err(CustomType(123)); + assert_eq!(None, result.get_label()); + + let result: Result = Ok(CustomType(123)); + assert_eq!(None, result.get_label()); +} diff --git a/autometrics/tests/integration_test.rs b/autometrics/tests/integration_test.rs index c539131..5a9206f 100644 --- a/autometrics/tests/integration_test.rs +++ b/autometrics/tests/integration_test.rs @@ -1,4 +1,5 @@ use autometrics::{autometrics, objectives::*}; +use autometrics_macros::AutometricsLabel; const OBJECTIVE: Objective = Objective::new("test") .success_rate(ObjectivePercentile::P99) @@ -11,8 +12,10 @@ fn main() { add(1, 2); other_function().unwrap(); + derived_label_function(); let result = autometrics::encode_global_metrics().unwrap(); + println!("{}", result); assert_ne!(result, ""); } @@ -70,3 +73,15 @@ pub fn some_function() -> Option { pub fn none_function() -> Option { Some("Hello world!".to_string()) } + +#[autometrics] +pub fn derived_label_function() -> DerivedLabel { + DerivedLabel::Foo +} + +#[derive(AutometricsLabel)] +#[autometrics_label(key = "derived_label")] +pub enum DerivedLabel { + #[autometrics_label()] + Foo, +} diff --git a/examples/full-api/Cargo.toml b/examples/full-api/Cargo.toml index c9c568c..957ab9c 100644 --- a/examples/full-api/Cargo.toml +++ b/examples/full-api/Cargo.toml @@ -11,6 +11,5 @@ axum = { version = "0.6", features = ["json"] } rand = "0.8" reqwest = { version = "0.11", features = ["json"] } serde = { version = "1", features = ["derive"] } -strum = { version = "0.24", features = ["derive"] } thiserror = "1" tokio = { version = "1", features = ["full"] } diff --git a/examples/full-api/src/error.rs b/examples/full-api/src/error.rs index 13b6f85..15b2482 100644 --- a/examples/full-api/src/error.rs +++ b/examples/full-api/src/error.rs @@ -1,24 +1,22 @@ +use autometrics::AutometricsLabel; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; -use strum::IntoStaticStr; -use thiserror::Error; -// We're using `thiserror` to define our error type, and we're using `strum` to -// enable the error variants to be turned into &'static str's, which -// will actually become another label on the call counter metric. +// We're using the `AutometricsLabel` derive to enable the error variants to be turned into labels +// on the call counter metric. // -// In this case, the label will be `error` = `not_found`, `bad_request`, or `internal`. +// In this case, the label will be `error` = `not_found`, `bad_request`, or `internal_server_error`. // // Instead of looking at high-level HTTP status codes in our metrics, // we'll instead see the actual variant name of the error. -#[derive(Debug, Error, IntoStaticStr)] -#[strum(serialize_all = "snake_case")] +#[derive(Debug, AutometricsLabel)] +#[autometrics_label(key = "error")] pub enum ApiError { - #[error("User not found")] + #[autometrics_label()] NotFound, - #[error("Bad request")] + #[autometrics_label()] BadRequest, - #[error("Internal server error")] + #[autometrics_label(value = "internal_server_error")] Internal, }