// kms_secp256k1_api/services/aws_kms_client_service.rs

1use crate::{
2    config::{AwsConfig, HashType},
3    services::{keys_service::KeyEntry, kms_client_service::KmsClientService},
4};
5use aws_config::SdkConfig;
6use aws_credential_types::Credentials;
7use aws_sdk_kms::{
8    Client as KmsClient,
9    primitives::Blob,
10    types::{KeyMetadata, KeySpec, KeyUsageType, MessageType, OriginType, SigningAlgorithmSpec},
11};
12use aws_types::region::Region;
13use base64::Engine;
14use base64::engine::general_purpose::STANDARD;
15use k256::sha2::{Digest, Sha256};
16// use sha3::{Digest as Sha3Digest, Sha3_256};
17use std::vec;
18use tracing::{error, info};
19
20pub struct AWSKmsClientService {
21    create: KmsClient,
22    sign: KmsClient,
23    delete: Option<KmsClient>,
24    list: Option<KmsClient>,
25    pub hash_type: HashType,
26}
27
28impl AWSKmsClientService {
29    /// Creates a new `AWSKmsClientService` with the provided AWS credentials.
30    ///
31    /// # Errors
32    ///
33    /// Returns an error string if the SDK configuration fails.
34    pub async fn new(aws_config: AwsConfig) -> Result<Self, String> {
35        let region = Region::new(aws_config.region.clone());
36        let endpoint = aws_config.endpoint.clone();
37
38        // Log AWS configuration for debugging
39        info!(
40            "Initializing AWS KMS Client - Region: {}, Endpoint: {}",
41            aws_config.region, endpoint
42        );
43
44        let create_creds = Credentials::new(
45            aws_config.create.access_key_id,
46            aws_config.create.secret_access_key,
47            None,
48            None,
49            "create_kms_credentials",
50        );
51
52        let sign_creds = Credentials::new(
53            aws_config.sign.access_key_id,
54            aws_config.sign.secret_access_key,
55            None,
56            None,
57            "sign_kms_credentials",
58        );
59
60        let create_sdk_config =
61            aws_config_to_sdk_config(region.clone(), endpoint.clone(), create_creds).await;
62        let sign_sdk_config =
63            aws_config_to_sdk_config(region.clone(), endpoint.clone(), sign_creds).await;
64
65        let create = KmsClient::new(&create_sdk_config);
66        let sign = KmsClient::new(&sign_sdk_config);
67
68        // Conditionally build delete if credentials exist
69        let delete = if let Some(delete_creds) = &aws_config.delete {
70            let delete_creds = Credentials::new(
71                delete_creds.access_key_id.clone(),
72                delete_creds.secret_access_key.clone(),
73                None,
74                None,
75                "delete_kms_credentials",
76            );
77            let delete_sdk_config =
78                aws_config_to_sdk_config(region.clone(), endpoint.clone(), delete_creds).await;
79            Some(KmsClient::new(&delete_sdk_config))
80        } else {
81            None
82        };
83
84        let list = if let Some(list_creds) = &aws_config.list {
85            let delete_creds = Credentials::new(
86                list_creds.access_key_id.clone(),
87                list_creds.secret_access_key.clone(),
88                None,
89                None,
90                "list_kms_credentials",
91            );
92            let delete_sdk_config =
93                aws_config_to_sdk_config(region.clone(), endpoint, delete_creds).await;
94            Some(KmsClient::new(&delete_sdk_config))
95        } else {
96            None
97        };
98
99        Ok(Self {
100            create,
101            sign,
102            delete,
103            list,
104            hash_type: aws_config.hash_type,
105        })
106    }
107
108    // Casper	SHA256 (x2)	RAW	Transaction hash, AWS re-hashes with SHA256
109    // Ethereum (EIP-155) Keccak256	DIGEST, AWS does not re-hashes
110    // Other	SHA3-256 DIGEST, AWS does not re-hashes
111    fn hash(&self, data: &[u8]) -> Vec<u8> {
112        match self.hash_type {
113            HashType::Sha256 => {
114                let mut hasher = Sha256::new();
115                hasher.update(data);
116                hasher.finalize().to_vec() // Casper Sha256 of Sha256 transaction_hash
117            }
118            HashType::Keccak256 => data.to_vec(), // EIP-155 transaction hash is Keccak256 hash already
119
120                                                  // HashType::Sha3_256 => {
121                                                  //     let mut hasher = Sha3_256::new();
122                                                  //     hasher.update(data);
123                                                  //     hasher.finalize().to_vec()
124                                                  // }
125        }
126    }
127}
128
129#[async_trait::async_trait]
130impl KmsClientService for AWSKmsClientService {
131    async fn create_key(&self) -> Result<(String, String), String> {
132        let kms_client = &self.create;
133
134        let create_key_output = kms_client
135            .create_key()
136            .description("SECP256K1 Key")
137            .key_usage(KeyUsageType::SignVerify)
138            .origin(OriginType::AwsKms)
139            .key_spec(KeySpec::EccSecgP256K1)
140            .send()
141            .await
142            .map_err(|e| {
143                let msg = format!("Error creating key: {e:?}");
144                error!("{}", &msg);
145                msg
146            })?;
147
148        let key_id = create_key_output
149            .key_metadata
150            .map(|meta| meta.key_id)
151            .ok_or_else(|| {
152                let msg = "No KeyId found".to_string();
153                error!("{}", &msg);
154                msg
155            })?;
156
157        // info!("Key created: {}", key_id);
158
159        // Fetch the public key
160        let public_key_base64 = self.get_public_key_base64(kms_client, &key_id).await?;
161
162        // info!("Public key (base64): {}", public_key_base64);
163        Ok((key_id, public_key_base64))
164    }
165
166    async fn create_alias(&self, key_id: &str, alias: &str) -> Result<(), String> {
167        let alias_name = Self::format_alias(alias);
168        let kms_client = &self.create;
169
170        kms_client
171            .create_alias()
172            .alias_name(&alias_name)
173            .target_key_id(key_id)
174            .send()
175            .await
176            .map_err(|e| {
177                let msg = format!("Error creating alias: {e:?}");
178                error!("{}", &msg);
179                msg
180            })?;
181        //   info!("Alias key: {}", alias_name);
182        Ok(())
183    }
184
185    async fn sign(&self, transaction_hash_hex: &str, key: &str) -> Result<String, String> {
186        let kms_client = &self.sign;
187
188        // Resolve key to key ID
189        let key_metadata = self.describe_key_metadata(kms_client, key).await?;
190        let key_id = key_metadata.key_id;
191
192        let data = hex::decode(transaction_hash_hex).map_err(|e| {
193            let msg = format!("Failed to decode transaction hash hex: {e}");
194            error!("{}", msg);
195            msg
196        })?;
197
198        // let transaction_hash = STANDARD.encode(data.clone());
199        // info!("transaction_hash: {}", transaction_hash);
200
201        let input_data = self.hash(&data);
202
203        // Sign the message
204        let sign_output = kms_client
205            .sign()
206            .key_id(&key_id)
207            .message(Blob::new(input_data))
208            .message_type(MessageType::Digest)
209            .signing_algorithm(SigningAlgorithmSpec::EcdsaSha256)
210            .send()
211            .await
212            .map_err(|e| {
213                let msg = format!("Error calling sign on KMS: {e:?}");
214                error!("{}", msg);
215                msg
216            })?;
217
218        let signature = sign_output.signature.ok_or_else(|| {
219            let msg = "No signature returned from AWS KMS".to_string();
220            error!("{}", msg);
221            msg
222        })?;
223
224        let signature = signature.as_ref();
225
226        // Log signature
227        // let signature_hex = hex::encode(signature);
228        // info!("signature hex: {}", signature_hex);
229        let signature = STANDARD.encode(signature);
230        // info!("signature: {}", signature);
231
232        Ok(signature)
233    }
234
235    async fn verify(
236        &self,
237        transaction_hash_hex: &str,
238        signature_asn1_base64: &str,
239        key: &str,
240    ) -> Result<bool, String> {
241        let kms_client = &self.sign;
242
243        // Resolve key to key ID
244        let key_metadata = self.describe_key_metadata(kms_client, key).await?;
245        let key_id = key_metadata.key_id;
246
247        // Decode the digest hex to bytes
248        let digest_bytes = hex::decode(transaction_hash_hex).map_err(|e| {
249            let msg = format!("Failed to decode transaction hash hex: {e}");
250            error!("{}", msg);
251            msg
252        })?;
253
254        // Decode the signature base64 to bytes (ASN.1 DER format expected)
255        let signature_bytes = STANDARD.decode(signature_asn1_base64).map_err(|e| {
256            let msg = format!("Failed to decode signature base64: {e}");
257            error!("{}", msg);
258            msg
259        })?;
260
261        let input_data = self.hash(&digest_bytes);
262
263        // Call AWS KMS verify
264        let verify_output = kms_client
265            .verify()
266            .key_id(&key_id)
267            .message(Blob::new(input_data))
268            .message_type(MessageType::Digest)
269            .signature(Blob::new(signature_bytes))
270            .signing_algorithm(SigningAlgorithmSpec::EcdsaSha256)
271            .send()
272            .await
273            .map_err(|e| {
274                let msg = format!("Error calling verify on KMS: {e:?}");
275                error!("{}", msg);
276                msg
277            })?;
278
279        // The 'signature_valid' field indicates validity
280        Ok(verify_output.signature_valid)
281    }
282
283    async fn delete_key(&self, key: &str) -> Result<bool, String> {
284        let Some(kms_client) = &self.delete else {
285            tracing::warn!(
286                "Attempted to delete key, but delete_kms client is not configured/enabled"
287            );
288            return Ok(false);
289        };
290
291        // Resolve key to key ID
292        let key_metadata = self.describe_key_metadata(kms_client, key).await?;
293
294        let key_id = key_metadata.key_id.clone();
295        let alias_name = Self::format_alias(key);
296
297        kms_client
298            .delete_alias()
299            .alias_name(&alias_name)
300            .send()
301            .await
302            .map_err(|e| {
303                let msg = format!("Failed to delete alias {alias_name}: {e:?}");
304                error!("{}", msg);
305                msg
306            })?;
307
308        kms_client
309            .schedule_key_deletion()
310            .key_id(&key_id)
311            .pending_window_in_days(7)
312            .send()
313            .await
314            .map_err(|e| {
315                let msg = format!("Failed to schedule deletion for key {key_id}: {e:?}");
316                error!("{}", msg);
317                msg
318            })?;
319
320        // info!(
321        //     "Key {} (alias {}) scheduled for deletion",
322        //     key_id, alias_name
323        // );
324        Ok(true)
325    }
326
327    async fn list_keys(&self) -> Result<Vec<KeyEntry>, String> {
328        let Some(kms_client) = &self.list else {
329            tracing::warn!("Attempted to list keys, but list_kms client is not configured/enabled");
330            return Ok(vec![]);
331        };
332
333        let mut paginator = kms_client.list_aliases().into_paginator().send();
334        let mut results = Vec::<KeyEntry>::new();
335
336        while let Some(page_result) = paginator.next().await {
337            let page = page_result.map_err(|e| {
338                let msg = format!("Failed to list aliases: {e:?}");
339                error!("{}", msg);
340                msg
341            })?;
342
343            let aliases = page.aliases.unwrap_or_default();
344
345            for alias in aliases {
346                let (Some(alias_name), Some(key_id)) = (alias.alias_name, alias.target_key_id)
347                else {
348                    continue;
349                };
350
351                if alias_name.starts_with("alias/aws/") {
352                    continue;
353                }
354
355                let key_metadata = self.describe_key_metadata(kms_client, &alias_name).await?;
356
357                let is_enabled =
358                    key_metadata.key_state.as_ref() == Some(&aws_sdk_kms::types::KeyState::Enabled);
359
360                if !is_enabled {
361                    continue;
362                }
363
364                // Skip symmetric keys as they don't have public keys
365                // Only process asymmetric keys (ECC, RSA, etc.)
366                if key_metadata.key_spec.as_ref() == Some(&KeySpec::SymmetricDefault) {
367                    continue;
368                }
369
370                let address = alias_name.trim_start_matches("alias/").to_string();
371                let public_key_base64 = self.get_public_key_base64(kms_client, &key_id).await?;
372
373                results.push(KeyEntry {
374                    address: address.into(),
375                    public_key_base64: public_key_base64.into(),
376                    public_key: None.into(), // recomputed later per keys service list_keys
377                    key_id: key_id.into(),
378                });
379            }
380        }
381
382        Ok(results)
383    }
384
385    /// Resolves an alias and fetches the base64-encoded public key associated with it.
386    ///
387    /// # Arguments
388    /// * `kms_client` - The AWS KMS client used for the request.
389    /// * `alias` - The alias name (with or without the `alias/` prefix).
390    ///
391    /// # Returns
392    /// * `Ok(String)` - The base64-encoded public key.
393    /// * `Err(String)` - If resolving the alias or fetching the key fails.
394    async fn get_public_key(&self, alias: &str) -> Result<String, String> {
395        // Ensure alias is in the correct format
396        let alias_name = if alias.starts_with("alias/") {
397            alias.to_string()
398        } else {
399            format!("alias/{alias}")
400        };
401
402        let kms_client = &self.sign; // Sign credentials are used to get a public key
403
404        let key_metadata = self.describe_key_metadata(kms_client, &alias_name).await?;
405        let key_id = &key_metadata.key_id;
406
407        self.get_public_key_base64(kms_client, key_id).await
408    }
409}
410
411impl AWSKmsClientService {
412    /// Retrieves the public key for the given key ID from AWS KMS.
413    ///
414    /// # Arguments
415    /// * `kms_client` - A reference to the KMS client used to call `get_public_key`.
416    /// * `key_id` - The ID of the key to retrieve the public key for.
417    ///
418    /// # Returns
419    /// * `Ok(String)` - The base64-encoded public key.
420    /// * `Err(String)` - Error message if the call fails or the key is missing.
421    async fn get_public_key_base64(
422        &self,
423        kms_client: &aws_sdk_kms::Client,
424        key_id: &str,
425    ) -> Result<String, String> {
426        let pubkey_resp = kms_client
427            .get_public_key()
428            .key_id(key_id)
429            .send()
430            .await
431            .map_err(|e| {
432                let msg = format!("Could not get public key for key_id {key_id}: {e:?}");
433                error!("{}", msg);
434                msg
435            })?;
436
437        let pubkey = pubkey_resp.public_key.ok_or_else(|| {
438            let msg = format!("Missing public key for key_id {key_id}");
439            error!("{}", msg);
440            msg
441        })?;
442
443        Ok(STANDARD.encode(pubkey.as_ref()))
444    }
445
446    /// Retrieves `KeyMetadata` for a given key key by calling AWS KMS `DescribeKey`.
447    ///
448    /// # Arguments
449    /// * `key` - The key name without the `alias/` prefix.
450    ///
451    /// # Returns
452    /// * `Ok(KeyMetadata)` if the key is found.
453    /// * `Err(String)` if the request fails or metadata is missing.
454    async fn describe_key_metadata(
455        &self,
456        kms_client: &aws_sdk_kms::Client,
457        key: &str,
458    ) -> Result<KeyMetadata, String> {
459        let alias_name = Self::format_alias(key);
460
461        let output = kms_client
462            .describe_key()
463            .key_id(&alias_name)
464            .send()
465            .await
466            .map_err(|e| {
467                let msg = format!("Failed to describe key for alias {alias_name}: {e:?}");
468                error!("{}", msg);
469                msg
470            })?;
471
472        output.key_metadata.ok_or_else(|| {
473            let msg = format!("KeyMetadata not found for alias {alias_name}");
474            error!("{}", msg);
475            msg
476        })
477    }
478
479    fn format_alias(alias: &str) -> String {
480        if alias.starts_with("alias/") {
481            alias.to_string()
482        } else {
483            format!("alias/{alias}")
484        }
485    }
486}
487
488async fn aws_config_to_sdk_config(
489    region: Region,
490    endpoint: String,
491    creds: Credentials,
492) -> SdkConfig {
493    let config_loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
494    config_loader
495        .region(region)
496        .endpoint_url(endpoint)
497        .credentials_provider(creds)
498        .load()
499        .await
500}
501
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        config::{AwsConfig, AwsCreds},
        constants::{AWS_KMS_ENDPOINT_PATTERN, DEFAULT_AWS_REGION},
    };

    /// Placeholder credential pair reused for every credential slot in these tests.
    fn dummy_creds() -> AwsCreds {
        AwsCreds {
            access_key_id: "dummy_access_key".into(),
            secret_access_key: "dummy_secret_key".into(),
        }
    }

    /// Constructs the service, failing the test with the construction error if any.
    async fn build_service(aws_config: AwsConfig) -> AWSKmsClientService {
        let result = AWSKmsClientService::new(aws_config).await;
        assert!(
            result.is_ok(),
            "Service construction failed: {:?}",
            result.err()
        );
        result.unwrap()
    }

    #[tokio::test]
    async fn test_aws_kms_client_service_new_success() {
        let endpoint = AWS_KMS_ENDPOINT_PATTERN.replace("{}", DEFAULT_AWS_REGION);
        let aws_config = AwsConfig {
            region: DEFAULT_AWS_REGION.into(),
            endpoint,
            create: dummy_creds(),
            sign: dummy_creds(),
            delete: None,
            list: None,
            ..Default::default()
        };

        let service = build_service(aws_config).await;

        // Optional clients stay unset; the default hash type is SHA-256.
        assert!(service.delete.is_none());
        assert!(service.list.is_none());
        assert_eq!(service.hash_type, HashType::Sha256);
    }

    #[tokio::test]
    async fn test_aws_kms_client_service_new_success_list_delete() {
        let endpoint = AWS_KMS_ENDPOINT_PATTERN.replace("{}", DEFAULT_AWS_REGION);
        let aws_config = AwsConfig {
            region: DEFAULT_AWS_REGION.into(),
            endpoint,
            create: dummy_creds(),
            sign: dummy_creds(),
            delete: Some(dummy_creds()),
            list: Some(dummy_creds()),
            hash_type: HashType::Keccak256,
        };

        let service = build_service(aws_config).await;

        // Both optional clients are built and the explicit hash type is kept.
        assert!(service.delete.is_some());
        assert!(service.list.is_some());
        assert_eq!(service.hash_type, HashType::Keccak256);
    }
}