chore: remove the google drive built-in extension (#4187)
This commit is contained in:
131
Cargo.lock
generated
131
Cargo.lock
generated
@@ -2543,89 +2543,6 @@ dependencies = [
|
||||
"regex-syntax 0.8.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "google-apis-common"
|
||||
version = "7.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7530ee92a7e9247c3294ae1b84ea98474dbc27563c49a14d3938e816499bf38f"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"chrono",
|
||||
"http 1.2.0",
|
||||
"http-body-util",
|
||||
"hyper 1.6.0",
|
||||
"hyper-util",
|
||||
"itertools 0.13.0",
|
||||
"mime",
|
||||
"percent-encoding",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_with",
|
||||
"tokio",
|
||||
"url",
|
||||
"yup-oauth2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "google-docs1"
|
||||
version = "6.0.0+20240613"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8441d3fa1544efacb0fabf88c45ba60d424d718bb13f2a0ce2a6447efb99d14e"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"google-apis-common",
|
||||
"hyper 1.6.0",
|
||||
"hyper-rustls 0.27.5",
|
||||
"hyper-util",
|
||||
"mime",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_with",
|
||||
"tokio",
|
||||
"url",
|
||||
"yup-oauth2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "google-drive3"
|
||||
version = "6.0.0+20240618"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "84e3944ee656d220932785cf1d8275519c0989830b9b239453983ac44f328d9f"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"google-apis-common",
|
||||
"hyper 1.6.0",
|
||||
"hyper-rustls 0.27.5",
|
||||
"hyper-util",
|
||||
"mime",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_with",
|
||||
"tokio",
|
||||
"url",
|
||||
"yup-oauth2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "google-sheets4"
|
||||
version = "6.0.0+20240621"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e4f8ccfc6418e81d1e2ed66fad49d0487526281505b8a0ed8ee770dc7d6bb1e5"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"google-apis-common",
|
||||
"hyper 1.6.0",
|
||||
"hyper-rustls 0.27.5",
|
||||
"hyper-util",
|
||||
"mime",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_with",
|
||||
"tokio",
|
||||
"url",
|
||||
"yup-oauth2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "goose"
|
||||
version = "1.5.0"
|
||||
@@ -2785,10 +2702,6 @@ dependencies = [
|
||||
"docx-rs",
|
||||
"etcetera",
|
||||
"glob",
|
||||
"google-apis-common",
|
||||
"google-docs1",
|
||||
"google-drive3",
|
||||
"google-sheets4",
|
||||
"http-body-util",
|
||||
"hyper 1.6.0",
|
||||
"ignore",
|
||||
@@ -4389,15 +4302,6 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_threads"
|
||||
version = "0.1.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c7398b9c8b70908f6371f47ed36737907c87c52af34c268fed0bf0ceb92ead9"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "number_prefix"
|
||||
version = "0.4.0"
|
||||
@@ -5814,12 +5718,6 @@ version = "3.0.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b07779b9b918cc05650cb30f404d4d7835d26df37c235eded8a6832e2fb82cca"
|
||||
|
||||
[[package]]
|
||||
name = "seahash"
|
||||
version = "4.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1c107b6f4780854c8b126e228ea8869f4d7b71260f962fefb57b996b8959ba6b"
|
||||
|
||||
[[package]]
|
||||
name = "security-framework"
|
||||
version = "2.11.1"
|
||||
@@ -6523,9 +6421,7 @@ checksum = "bb041120f25f8fbe8fd2dbe4671c7c2ed74d83be2e7a77529bf7e0790ae3f472"
|
||||
dependencies = [
|
||||
"deranged",
|
||||
"itoa",
|
||||
"libc",
|
||||
"num-conv",
|
||||
"num_threads",
|
||||
"powerfmt",
|
||||
"serde",
|
||||
"time-core",
|
||||
@@ -8083,33 +7979,6 @@ dependencies = [
|
||||
"synstructure",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yup-oauth2"
|
||||
version = "11.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ed5f19242090128c5809f6535cc7b8d4e2c32433f6c6005800bbc20a644a7f0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"base64 0.22.1",
|
||||
"futures",
|
||||
"http 1.2.0",
|
||||
"http-body-util",
|
||||
"hyper 1.6.0",
|
||||
"hyper-rustls 0.27.5",
|
||||
"hyper-util",
|
||||
"log",
|
||||
"percent-encoding",
|
||||
"rustls 0.23.23",
|
||||
"rustls-pemfile 2.2.0",
|
||||
"seahash",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"time",
|
||||
"tokio",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy"
|
||||
version = "0.7.35"
|
||||
|
||||
@@ -20,14 +20,6 @@ crates/goose-mcp/src/computercontroller/mod.rs::xlsx_tool
|
||||
crates/goose-mcp/src/computercontroller/pdf_tool.rs::pdf_tool
|
||||
crates/goose-mcp/src/developer/mod.rs::bash
|
||||
crates/goose-mcp/src/developer/mod.rs::new
|
||||
crates/goose-mcp/src/google_drive/google_labels.rs::doit
|
||||
crates/goose-mcp/src/google_drive/mod.rs::create_file
|
||||
crates/goose-mcp/src/google_drive/mod.rs::docs_tool
|
||||
crates/goose-mcp/src/google_drive/mod.rs::new
|
||||
crates/goose-mcp/src/google_drive/mod.rs::search_files
|
||||
crates/goose-mcp/src/google_drive/mod.rs::sharing
|
||||
crates/goose-mcp/src/google_drive/mod.rs::sheets_tool
|
||||
crates/goose-mcp/src/google_drive/mod.rs::update_label
|
||||
crates/goose-mcp/src/memory/mod.rs::new
|
||||
crates/goose-server/src/openapi.rs::convert_typed_schema
|
||||
crates/goose-server/src/openapi.rs::convert_typed_schema
|
||||
|
||||
@@ -721,7 +721,7 @@ pub async fn cli() -> Result<()> {
|
||||
return Ok(());
|
||||
}
|
||||
Some(Command::Mcp { name }) => {
|
||||
let _ = run_server(&name).await;
|
||||
run_server(&name).await?;
|
||||
}
|
||||
Some(Command::Session {
|
||||
command,
|
||||
|
||||
@@ -31,7 +31,6 @@ fn get_display_name(extension_id: &str) -> String {
|
||||
match extension_id {
|
||||
"developer" => "Developer Tools".to_string(),
|
||||
"computercontroller" => "Computer Controller".to_string(),
|
||||
"googledrive" => "Google Drive".to_string(),
|
||||
"memory" => "Memory".to_string(),
|
||||
"tutorial" => "Tutorial".to_string(),
|
||||
"jetbrains" => "JetBrains".to_string(),
|
||||
@@ -735,11 +734,6 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
|
||||
"Developer Tools",
|
||||
"Code editing and shell access",
|
||||
)
|
||||
.item(
|
||||
"googledrive",
|
||||
"Google Drive",
|
||||
"Search and read content from google drive - additional config required",
|
||||
)
|
||||
.item("jetbrains", "JetBrains", "Connect to jetbrains IDEs")
|
||||
.item(
|
||||
"memory",
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
use anyhow::Result;
|
||||
use goose_mcp::{
|
||||
ComputerControllerRouter, DeveloperRouter, GoogleDriveRouter, MemoryRouter, TutorialRouter,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use goose_mcp::{ComputerControllerRouter, DeveloperRouter, MemoryRouter, TutorialRouter};
|
||||
use mcp_server::router::RouterService;
|
||||
use mcp_server::{BoundedService, ByteTransport, Server};
|
||||
use tokio::io::{stdin, stdout};
|
||||
@@ -17,34 +15,32 @@ use nix::unistd::getpgrp;
|
||||
use nix::unistd::Pid;
|
||||
|
||||
pub async fn run_server(name: &str) -> Result<()> {
|
||||
// Initialize logging
|
||||
crate::logging::setup_logging(Some(&format!("mcp-{name}")), None)?;
|
||||
|
||||
if name == "googledrive" || name == "google_drive" {
|
||||
return Err(anyhow!(
|
||||
"the built-in Google Drive extension has been removed"
|
||||
));
|
||||
}
|
||||
|
||||
tracing::info!("Starting MCP server");
|
||||
|
||||
let router: Option<Box<dyn BoundedService>> = match name {
|
||||
"developer" => Some(Box::new(RouterService(DeveloperRouter::new()))),
|
||||
"computercontroller" => Some(Box::new(RouterService(ComputerControllerRouter::new()))),
|
||||
"google_drive" | "googledrive" => {
|
||||
let router = GoogleDriveRouter::new().await;
|
||||
Some(Box::new(RouterService(router)))
|
||||
}
|
||||
"memory" => Some(Box::new(RouterService(MemoryRouter::new()))),
|
||||
"tutorial" => Some(Box::new(RouterService(TutorialRouter::new()))),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
// Create shutdown notification channel
|
||||
let shutdown = Arc::new(Notify::new());
|
||||
let shutdown_clone = shutdown.clone();
|
||||
|
||||
// Spawn shutdown signal handler
|
||||
tokio::spawn(async move {
|
||||
crate::signal::shutdown_signal().await;
|
||||
shutdown_clone.notify_one();
|
||||
});
|
||||
|
||||
// Create and run the server
|
||||
let server = Server::new(router.unwrap_or_else(|| panic!("Unknown server requested {}", name)));
|
||||
let transport = ByteTransport::new(stdin(), stdout());
|
||||
|
||||
|
||||
@@ -38,10 +38,6 @@ chrono = { version = "0.4.38", features = ["serde"] }
|
||||
etcetera = "0.8.0"
|
||||
tempfile = "3.8"
|
||||
include_dir = "0.7.4"
|
||||
google-apis-common = "7.0.0"
|
||||
google-drive3 = "6.0.0"
|
||||
google-sheets4 = "6.0.0"
|
||||
google-docs1 = "6.0.0"
|
||||
webbrowser = "0.8"
|
||||
http-body-util = "0.1.2"
|
||||
regex = "1.11.1"
|
||||
|
||||
@@ -1,478 +0,0 @@
|
||||
#![allow(clippy::ptr_arg, dead_code, clippy::enum_variant_names)]
|
||||
|
||||
use std::collections::{BTreeSet, HashMap};
|
||||
|
||||
use google_apis_common as common;
|
||||
use tokio::time::sleep;
|
||||
|
||||
/// A scope is needed when requesting an
|
||||
/// [authorization token](https://developers.google.com/workspace/drive/labels/guides/authorize).
|
||||
#[derive(PartialEq, Eq, Ord, PartialOrd, Hash, Debug, Clone, Copy)]
|
||||
pub enum Scope {
|
||||
/// View, use, and manage Drive labels.
|
||||
DriveLabels,
|
||||
|
||||
/// View and use Drive labels.
|
||||
DriveLabelsReadonly,
|
||||
|
||||
/// View, edit, create, and delete all Drive labels in your organization,
|
||||
/// and view your organization's label-related administration policies.
|
||||
DriveLabelsAdmin,
|
||||
|
||||
/// View all Drive labels and label-related administration policies in your
|
||||
/// organization.
|
||||
DriveLabelsAdminReadonly,
|
||||
}
|
||||
|
||||
impl AsRef<str> for Scope {
|
||||
fn as_ref(&self) -> &str {
|
||||
match *self {
|
||||
Scope::DriveLabels => "https://www.googleapis.com/auth/drive.labels",
|
||||
Scope::DriveLabelsReadonly => "https://www.googleapis.com/auth/drive.labels.readonly",
|
||||
Scope::DriveLabelsAdmin => "https://www.googleapis.com/auth/drive.admin.labels",
|
||||
Scope::DriveLabelsAdminReadonly => {
|
||||
"https://www.googleapis.com/auth/drive.admin.labels.readonly"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::derivable_impls)]
|
||||
impl Default for Scope {
|
||||
fn default() -> Scope {
|
||||
Scope::DriveLabelsReadonly
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct DriveLabelsHub<C> {
|
||||
pub client: common::Client<C>,
|
||||
pub auth: Box<dyn common::GetToken>,
|
||||
_user_agent: String,
|
||||
_base_url: String,
|
||||
}
|
||||
|
||||
impl<C> common::Hub for DriveLabelsHub<C> {}
|
||||
|
||||
impl<'a, C> DriveLabelsHub<C> {
|
||||
pub fn new<A: 'static + common::GetToken>(
|
||||
client: common::Client<C>,
|
||||
auth: A,
|
||||
) -> DriveLabelsHub<C> {
|
||||
DriveLabelsHub {
|
||||
client,
|
||||
auth: Box::new(auth),
|
||||
_user_agent: "google-api-rust-client/6.0.0".to_string(),
|
||||
_base_url: "https://drivelabels.googleapis.com/".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn labels(&'a self) -> LabelMethods<'a, C> {
|
||||
LabelMethods { hub: self }
|
||||
}
|
||||
|
||||
/// Set the user-agent header field to use in all requests to the server.
|
||||
/// It defaults to `google-api-rust-client/6.0.0`.
|
||||
///
|
||||
/// Returns the previously set user-agent.
|
||||
pub fn user_agent(&mut self, agent_name: String) -> String {
|
||||
std::mem::replace(&mut self._user_agent, agent_name)
|
||||
}
|
||||
|
||||
/// Set the base url to use in all requests to the server.
|
||||
/// It defaults to `https://www.googleapis.com/drive/v3/`.
|
||||
///
|
||||
/// Returns the previously set base url.
|
||||
pub fn base_url(&mut self, new_base_url: String) -> String {
|
||||
std::mem::replace(&mut self._base_url, new_base_url)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
|
||||
#[serde_with::serde_as]
|
||||
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct Label {
|
||||
#[serde(rename = "name")]
|
||||
pub name: Option<String>,
|
||||
#[serde(rename = "id")]
|
||||
pub id: Option<String>,
|
||||
#[serde(rename = "revisionId")]
|
||||
pub revision_id: Option<String>,
|
||||
#[serde(rename = "labelType")]
|
||||
pub label_type: Option<String>,
|
||||
#[serde(rename = "creator")]
|
||||
pub creator: Option<User>,
|
||||
#[serde(rename = "createTime")]
|
||||
pub create_time: Option<String>,
|
||||
#[serde(rename = "revisionCreator")]
|
||||
pub revision_creator: Option<User>,
|
||||
#[serde(rename = "revisionCreateTime")]
|
||||
pub revision_create_time: Option<String>,
|
||||
#[serde(rename = "publisher")]
|
||||
pub publisher: Option<User>,
|
||||
#[serde(rename = "publishTime")]
|
||||
pub publish_time: Option<String>,
|
||||
#[serde(rename = "disabler")]
|
||||
pub disabler: Option<User>,
|
||||
#[serde(rename = "disableTime")]
|
||||
pub disable_time: Option<String>,
|
||||
#[serde(rename = "customer")]
|
||||
pub customer: Option<String>,
|
||||
pub properties: Option<LabelProperty>,
|
||||
pub fields: Option<Vec<Field>>,
|
||||
// We ignore the remaining fields.
|
||||
}
|
||||
|
||||
impl common::Part for Label {}
|
||||
|
||||
impl common::ResponseResult for Label {}
|
||||
|
||||
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
|
||||
#[serde_with::serde_as]
|
||||
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct LabelProperty {
|
||||
pub title: Option<String>,
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
impl common::Part for LabelProperty {}
|
||||
|
||||
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
|
||||
#[serde_with::serde_as]
|
||||
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct Field {
|
||||
id: Option<String>,
|
||||
#[serde(rename = "queryKey")]
|
||||
query_key: Option<String>,
|
||||
properties: Option<FieldProperty>,
|
||||
#[serde(rename = "selectionOptions")]
|
||||
selection_options: Option<SelectionOption>,
|
||||
}
|
||||
|
||||
impl common::Part for Field {}
|
||||
|
||||
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
|
||||
#[serde_with::serde_as]
|
||||
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct FieldProperty {
|
||||
#[serde(rename = "displayName")]
|
||||
pub display_name: Option<String>,
|
||||
pub required: Option<bool>,
|
||||
}
|
||||
|
||||
impl common::Part for FieldProperty {}
|
||||
|
||||
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
|
||||
#[serde_with::serde_as]
|
||||
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct SelectionOption {
|
||||
#[serde(rename = "listOptions")]
|
||||
pub list_options: Option<String>,
|
||||
pub choices: Option<Vec<Choice>>,
|
||||
}
|
||||
|
||||
impl common::Part for SelectionOption {}
|
||||
|
||||
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
|
||||
#[serde_with::serde_as]
|
||||
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct Choice {
|
||||
id: Option<String>,
|
||||
properties: Option<ChoiceProperties>,
|
||||
// We ignore the remaining fields.
|
||||
}
|
||||
|
||||
impl common::Part for Choice {}
|
||||
|
||||
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
|
||||
#[serde_with::serde_as]
|
||||
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct ChoiceProperties {
|
||||
#[serde(rename = "displayName")]
|
||||
display_name: Option<String>,
|
||||
description: Option<String>,
|
||||
}
|
||||
|
||||
impl common::Part for ChoiceProperties {}
|
||||
|
||||
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
|
||||
#[serde_with::serde_as]
|
||||
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct LabelList {
|
||||
pub labels: Option<Vec<Label>>,
|
||||
#[serde(rename = "nextPageToken")]
|
||||
pub next_page_token: Option<String>,
|
||||
}
|
||||
|
||||
impl common::ResponseResult for LabelList {}
|
||||
|
||||
/// Information about a Drive user.
|
||||
///
|
||||
/// This type is not used in any activity, and only used as *part* of another schema.
|
||||
///
|
||||
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
|
||||
#[serde_with::serde_as]
|
||||
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct User {
|
||||
/// Output only. A plain text displayable name for this user.
|
||||
#[serde(rename = "displayName")]
|
||||
pub display_name: Option<String>,
|
||||
/// Output only. The email address of the user. This may not be present in certain contexts if the user has not made their email address visible to the requester.
|
||||
#[serde(rename = "emailAddress")]
|
||||
pub email_address: Option<String>,
|
||||
/// Output only. Identifies what kind of resource this is. Value: the fixed string `"drive#user"`.
|
||||
pub kind: Option<String>,
|
||||
/// Output only. Whether this user is the requesting user.
|
||||
pub me: Option<bool>,
|
||||
/// Output only. The user's ID as visible in Permission resources.
|
||||
#[serde(rename = "permissionId")]
|
||||
pub permission_id: Option<String>,
|
||||
/// Output only. A link to the user's profile photo, if available.
|
||||
#[serde(rename = "photoLink")]
|
||||
pub photo_link: Option<String>,
|
||||
}
|
||||
|
||||
impl common::Part for User {}
|
||||
|
||||
pub struct LabelMethods<'a, C>
|
||||
where
|
||||
C: 'a,
|
||||
{
|
||||
hub: &'a DriveLabelsHub<C>,
|
||||
}
|
||||
|
||||
impl<C> common::MethodsBuilder for LabelMethods<'_, C> {}
|
||||
|
||||
impl<'a, C> LabelMethods<'a, C> {
|
||||
/// Create a builder to help you perform the following tasks:
|
||||
///
|
||||
/// List labels
|
||||
pub fn list(&self) -> LabelListCall<'a, C> {
|
||||
LabelListCall {
|
||||
hub: self.hub,
|
||||
_delegate: Default::default(),
|
||||
_additional_params: Default::default(),
|
||||
_scopes: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Lists the workspace's labels.
|
||||
pub struct LabelListCall<'a, C>
|
||||
where
|
||||
C: 'a,
|
||||
{
|
||||
hub: &'a DriveLabelsHub<C>,
|
||||
_delegate: Option<&'a mut dyn common::Delegate>,
|
||||
_additional_params: HashMap<String, String>,
|
||||
_scopes: BTreeSet<String>,
|
||||
}
|
||||
|
||||
impl<C> common::CallBuilder for LabelListCall<'_, C> {}
|
||||
|
||||
impl<'a, C> LabelListCall<'a, C>
|
||||
where
|
||||
C: common::Connector,
|
||||
{
|
||||
/// Perform the operation you have built so far.
|
||||
pub async fn doit(mut self) -> common::Result<(common::Response, LabelList)> {
|
||||
use common::url::Params;
|
||||
use hyper::header::{AUTHORIZATION, CONTENT_LENGTH, USER_AGENT};
|
||||
|
||||
let mut dd = common::DefaultDelegate;
|
||||
let dlg: &mut dyn common::Delegate = self._delegate.unwrap_or(&mut dd);
|
||||
dlg.begin(common::MethodInfo {
|
||||
id: "drivelabels.labels.list",
|
||||
http_method: hyper::Method::GET,
|
||||
});
|
||||
|
||||
for &field in ["alt"].iter() {
|
||||
if self._additional_params.contains_key(field) {
|
||||
dlg.finished(false);
|
||||
return Err(common::Error::FieldClash(field));
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: We don't handle any of the query params.
|
||||
let mut params = Params::with_capacity(2 + self._additional_params.len());
|
||||
|
||||
params.extend(self._additional_params.iter());
|
||||
|
||||
params.push("alt", "json");
|
||||
let url = self.hub._base_url.clone() + "v2/labels";
|
||||
|
||||
if self._scopes.is_empty() {
|
||||
self._scopes
|
||||
.insert(Scope::DriveLabelsReadonly.as_ref().to_string());
|
||||
}
|
||||
|
||||
let url = params.parse_with_url(&url);
|
||||
|
||||
loop {
|
||||
let token = match self
|
||||
.hub
|
||||
.auth
|
||||
.get_token(&self._scopes.iter().map(String::as_str).collect::<Vec<_>>()[..])
|
||||
.await
|
||||
{
|
||||
Ok(token) => token,
|
||||
Err(e) => match dlg.token(e) {
|
||||
Ok(token) => token,
|
||||
Err(e) => {
|
||||
dlg.finished(false);
|
||||
return Err(common::Error::MissingToken(e));
|
||||
}
|
||||
},
|
||||
};
|
||||
let req_result = {
|
||||
let client = &self.hub.client;
|
||||
dlg.pre_request();
|
||||
let mut req_builder = hyper::Request::builder()
|
||||
.method(hyper::Method::GET)
|
||||
.uri(url.as_str())
|
||||
.header(USER_AGENT, self.hub._user_agent.clone());
|
||||
|
||||
if let Some(token) = token.as_ref() {
|
||||
req_builder = req_builder.header(AUTHORIZATION, format!("Bearer {}", token));
|
||||
}
|
||||
|
||||
let request = req_builder
|
||||
.header(CONTENT_LENGTH, 0_u64)
|
||||
.body(common::to_body::<String>(None));
|
||||
client.request(request.unwrap()).await
|
||||
};
|
||||
|
||||
match req_result {
|
||||
Err(err) => {
|
||||
if let common::Retry::After(d) = dlg.http_error(&err) {
|
||||
sleep(d).await;
|
||||
continue;
|
||||
}
|
||||
dlg.finished(false);
|
||||
return Err(common::Error::HttpError(err));
|
||||
}
|
||||
Ok(res) => {
|
||||
let (parts, body) = res.into_parts();
|
||||
let body = common::Body::new(body);
|
||||
if !parts.status.is_success() {
|
||||
let bytes = common::to_bytes(body).await.unwrap_or_default();
|
||||
let error = serde_json::from_str(&common::to_string(&bytes));
|
||||
let response = common::to_response(parts, bytes.into());
|
||||
|
||||
if let common::Retry::After(d) =
|
||||
dlg.http_failure(&response, error.as_ref().ok())
|
||||
{
|
||||
sleep(d).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
dlg.finished(false);
|
||||
|
||||
return Err(match error {
|
||||
Ok(value) => common::Error::BadRequest(value),
|
||||
_ => common::Error::Failure(response),
|
||||
});
|
||||
}
|
||||
let response = {
|
||||
let bytes = common::to_bytes(body).await.unwrap_or_default();
|
||||
let encoded = common::to_string(&bytes);
|
||||
match serde_json::from_str(&encoded) {
|
||||
Ok(decoded) => (common::to_response(parts, bytes.into()), decoded),
|
||||
Err(error) => {
|
||||
dlg.response_json_decode_error(&encoded, &error);
|
||||
return Err(common::Error::JsonDecodeError(
|
||||
encoded.to_string(),
|
||||
error,
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
dlg.finished(true);
|
||||
return Ok(response);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The delegate implementation is consulted whenever there is an intermediate result, or if something goes wrong
|
||||
/// while executing the actual API request.
|
||||
///
|
||||
/// ````text
|
||||
/// It should be used to handle progress information, and to implement a certain level of resilience.
|
||||
/// ````
|
||||
///
|
||||
/// Sets the *delegate* property to the given value.
|
||||
pub fn delegate(mut self, new_value: &'a mut dyn common::Delegate) -> LabelListCall<'a, C> {
|
||||
self._delegate = Some(new_value);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set any additional parameter of the query string used in the request.
|
||||
/// It should be used to set parameters which are not yet available through their own
|
||||
/// setters.
|
||||
///
|
||||
/// Please note that this method must not be used to set any of the known parameters
|
||||
/// which have their own setter method. If done anyway, the request will fail.
|
||||
///
|
||||
/// # Additional Parameters
|
||||
///
|
||||
/// * *$.xgafv* (query-string) - V1 error format.
|
||||
/// * *access_token* (query-string) - OAuth access token.
|
||||
/// * *alt* (query-string) - Data format for response.
|
||||
/// * *callback* (query-string) - JSONP
|
||||
/// * *fields* (query-string) - Selector specifying which fields to include in a partial response.
|
||||
/// * *key* (query-string) - API key. Your API key identifies your project and provides you with API access, quota, and reports. Required unless you provide an OAuth 2.0 token.
|
||||
/// * *oauth_token* (query-string) - OAuth 2.0 token for the current user.
|
||||
/// * *prettyPrint* (query-boolean) - Returns response with indentations and line breaks.
|
||||
/// * *quotaUser* (query-string) - Available to use for quota purposes for server-side applications. Can be any arbitrary string assigned to a user, but should not exceed 40 characters.
|
||||
/// * *uploadType* (query-string) - Legacy upload protocol for media (e.g. "media", "multipart").
|
||||
/// * *upload_protocol* (query-string) - Upload protocol for media (e.g. "raw", "multipart").
|
||||
pub fn param<T>(mut self, name: T, value: T) -> LabelListCall<'a, C>
|
||||
where
|
||||
T: AsRef<str>,
|
||||
{
|
||||
self._additional_params
|
||||
.insert(name.as_ref().to_string(), value.as_ref().to_string());
|
||||
self
|
||||
}
|
||||
|
||||
/// Identifies the authorization scope for the method you are building.
|
||||
///
|
||||
/// Use this method to actively specify which scope should be used, instead of the default [`Scope`] variant
|
||||
/// [`Scope::DriveLabelsReadonly`].
|
||||
///
|
||||
/// The `scope` will be added to a set of scopes. This is important as one can maintain access
|
||||
/// tokens for more than one scope.
|
||||
///
|
||||
/// Usually there is more than one suitable scope to authorize an operation, some of which may
|
||||
/// encompass more rights than others. For example, for listing resources, a *read-only* scope will be
|
||||
/// sufficient, a read-write scope will do as well.
|
||||
pub fn add_scope<St>(mut self, scope: St) -> LabelListCall<'a, C>
|
||||
where
|
||||
St: AsRef<str>,
|
||||
{
|
||||
self._scopes.insert(String::from(scope.as_ref()));
|
||||
self
|
||||
}
|
||||
/// Identifies the authorization scope(s) for the method you are building.
|
||||
///
|
||||
/// See [`Self::add_scope()`] for details.
|
||||
pub fn add_scopes<I, St>(mut self, scopes: I) -> LabelListCall<'a, C>
|
||||
where
|
||||
I: IntoIterator<Item = St>,
|
||||
St: AsRef<str>,
|
||||
{
|
||||
self._scopes
|
||||
.extend(scopes.into_iter().map(|s| String::from(s.as_ref())));
|
||||
self
|
||||
}
|
||||
|
||||
/// Removes all scopes, and no default scope will be used either.
|
||||
/// In this case, you have to specify your API-key using the `key` parameter (see [`Self::param()`]
|
||||
/// for details).
|
||||
pub fn clear_scopes(mut self) -> LabelListCall<'a, C> {
|
||||
self._scopes.clear();
|
||||
self
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,386 +0,0 @@
|
||||
use std::error::Error;
|
||||
use std::fs;
|
||||
use std::future::Future;
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
use std::net::TcpListener;
|
||||
use std::path::Path;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use google_drive3::common::GetToken;
|
||||
use oauth2::basic::BasicClient;
|
||||
use oauth2::reqwest;
|
||||
use oauth2::{
|
||||
AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EndpointNotSet, EndpointSet,
|
||||
PkceCodeChallenge, RedirectUrl, RefreshToken, Scope, TokenResponse, TokenUrl,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{debug, error, info};
|
||||
use url::Url;
|
||||
|
||||
use super::storage::CredentialsManager;
|
||||
|
||||
/// Structure representing the OAuth2 configuration file format
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct OAuth2Config {
|
||||
installed: InstalledConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct InstalledConfig {
|
||||
client_id: String,
|
||||
project_id: String,
|
||||
auth_uri: String,
|
||||
token_uri: String,
|
||||
auth_provider_x509_cert_url: String,
|
||||
client_secret: String,
|
||||
redirect_uris: Vec<String>,
|
||||
}
|
||||
|
||||
/// Structure for token storage
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct TokenData {
|
||||
access_token: String,
|
||||
refresh_token: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
expires_at: Option<u64>,
|
||||
project_id: String,
|
||||
scopes: Vec<String>,
|
||||
}
|
||||
|
||||
/// PkceOAuth2Client implements the GetToken trait required by DriveHub
|
||||
/// It uses the oauth2 crate to implement a PKCE-enabled OAuth2 flow
|
||||
#[derive(Clone)]
|
||||
pub struct PkceOAuth2Client {
|
||||
client: BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>,
|
||||
credentials_manager: Arc<CredentialsManager>,
|
||||
http_client: reqwest::Client,
|
||||
project_id: String,
|
||||
scopes: Vec<String>,
|
||||
}
|
||||
|
||||
impl PkceOAuth2Client {
|
||||
pub fn new(
|
||||
config_path: impl AsRef<Path>,
|
||||
credentials_manager: Arc<CredentialsManager>,
|
||||
) -> Result<Self, Box<dyn Error + Send + Sync>> {
|
||||
// Load and parse the config file
|
||||
let config_content = fs::read_to_string(config_path)?;
|
||||
let config: OAuth2Config = serde_json::from_str(&config_content)?;
|
||||
|
||||
// Extract the project_id from the config
|
||||
let project_id = config.installed.project_id.clone();
|
||||
let scopes = vec![];
|
||||
|
||||
// Create OAuth URLs
|
||||
let auth_url =
|
||||
AuthUrl::new(config.installed.auth_uri).expect("Invalid authorization endpoint URL");
|
||||
let token_url =
|
||||
TokenUrl::new(config.installed.token_uri).expect("Invalid token endpoint URL");
|
||||
|
||||
// Set up the OAuth2 client
|
||||
let client = BasicClient::new(ClientId::new(config.installed.client_id))
|
||||
.set_client_secret(ClientSecret::new(config.installed.client_secret))
|
||||
.set_auth_uri(auth_url)
|
||||
.set_token_uri(token_url)
|
||||
.set_redirect_uri(
|
||||
RedirectUrl::new("http://localhost:18080".to_string())
|
||||
.expect("Invalid redirect URL"),
|
||||
);
|
||||
|
||||
let http_client = reqwest::ClientBuilder::new()
|
||||
// Following redirects opens the client up to SSRF vulnerabilities.
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.build()
|
||||
.expect("Oauth2 HTTP Client should build");
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
credentials_manager,
|
||||
http_client,
|
||||
project_id,
|
||||
scopes,
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if a token is expired or about to expire within the buffer period
|
||||
fn is_token_expired(&self, expires_at: Option<u64>, buffer_seconds: u64) -> bool {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("Time went backwards")
|
||||
.as_secs();
|
||||
|
||||
// Consider the token expired if it's within buffer_seconds of expiring
|
||||
// This gives us a safety margin to avoid using tokens right before expiration
|
||||
expires_at
|
||||
.map(|expiry_time| now + buffer_seconds >= expiry_time)
|
||||
.unwrap_or(true) // If we don't know when it expires, assume it's expired to be safe
|
||||
}
|
||||
|
||||
async fn perform_oauth_flow(
|
||||
&self,
|
||||
scopes: &[&str],
|
||||
) -> Result<String, Box<dyn Error + Send + Sync>> {
|
||||
// Create a PKCE code verifier and challenge
|
||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
|
||||
// Generate the authorization URL
|
||||
let (auth_url, csrf_token) = self
|
||||
.client
|
||||
.authorize_url(CsrfToken::new_random)
|
||||
.add_scopes(scopes.iter().map(|&s| Scope::new(s.to_string())))
|
||||
.set_pkce_challenge(pkce_challenge)
|
||||
.url();
|
||||
|
||||
info!("Opening browser for OAuth2 authentication");
|
||||
if let Err(e) = webbrowser::open(auth_url.as_str()) {
|
||||
error!("Failed to open browser: {}", e);
|
||||
println!("Please open this URL in your browser:\n{}\n", auth_url);
|
||||
}
|
||||
|
||||
// Start a local server to receive the authorization code
|
||||
// We'll spawn this in a separate thread since it's blocking
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
std::thread::spawn(move || match Self::start_redirect_server() {
|
||||
Ok(result) => {
|
||||
let _ = tx.send(Ok(result));
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(e));
|
||||
}
|
||||
});
|
||||
|
||||
// Wait for the code from the redirect server
|
||||
let (code, received_state) = rx.await??;
|
||||
|
||||
// Verify the CSRF state
|
||||
if received_state.secret() != csrf_token.secret() {
|
||||
return Err("CSRF token mismatch".into());
|
||||
}
|
||||
|
||||
// Use the built-in exchange_code method with PKCE verifier
|
||||
let token_result = self
|
||||
.client
|
||||
.exchange_code(code)
|
||||
.set_pkce_verifier(pkce_verifier)
|
||||
.request_async(&self.http_client)
|
||||
.await
|
||||
.map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync>)?;
|
||||
|
||||
let access_token = token_result.access_token().secret().clone();
|
||||
|
||||
// Calculate expires_at as a Unix timestamp by adding expires_in to current time
|
||||
let expires_at = token_result.expires_in().map(|duration| {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("Time went backwards")
|
||||
.as_secs();
|
||||
now + duration.as_secs()
|
||||
});
|
||||
|
||||
// Get the refresh token if provided
|
||||
if let Some(refresh_token) = token_result.refresh_token() {
|
||||
let refresh_token_str = refresh_token.secret().clone();
|
||||
|
||||
// Store token data
|
||||
let token_data = TokenData {
|
||||
access_token: access_token.clone(),
|
||||
refresh_token: refresh_token_str.clone(),
|
||||
expires_at,
|
||||
project_id: self.project_id.clone(),
|
||||
scopes: scopes.iter().map(|s| s.to_string()).collect(),
|
||||
};
|
||||
|
||||
// Store updated token data
|
||||
self.credentials_manager
|
||||
.write_credentials(&token_data)
|
||||
.map(|_| debug!("Successfully stored token data"))
|
||||
.unwrap_or_else(|e| error!("Failed to store token data: {}", e));
|
||||
} else {
|
||||
debug!("No refresh token provided in OAuth flow response");
|
||||
}
|
||||
|
||||
Ok(access_token)
|
||||
}
|
||||
|
||||
async fn refresh_token(
|
||||
&self,
|
||||
refresh_token: &str,
|
||||
) -> Result<String, Box<dyn Error + Send + Sync>> {
|
||||
debug!("Attempting to refresh access token");
|
||||
|
||||
// Create a RefreshToken from the string
|
||||
let refresh_token = RefreshToken::new(refresh_token.to_string());
|
||||
|
||||
// Use the built-in exchange_refresh_token method
|
||||
let token_result = self
|
||||
.client
|
||||
.exchange_refresh_token(&refresh_token)
|
||||
.request_async(&self.http_client)
|
||||
.await
|
||||
.map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync>)?;
|
||||
|
||||
let access_token = token_result.access_token().secret().clone();
|
||||
|
||||
// Calculate expires_at as a Unix timestamp by adding expires_in to current time
|
||||
let expires_at = token_result.expires_in().map(|duration| {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("Time went backwards")
|
||||
.as_secs();
|
||||
now + duration.as_secs()
|
||||
});
|
||||
|
||||
// Get the refresh token - either the new one or reuse the existing one
|
||||
let new_refresh_token = token_result
|
||||
.refresh_token()
|
||||
.map(|token| token.secret().clone())
|
||||
.unwrap_or_else(|| refresh_token.secret().to_string());
|
||||
|
||||
// Always update the token data with the new access token and expiration
|
||||
let token_data = TokenData {
|
||||
access_token: access_token.clone(),
|
||||
refresh_token: new_refresh_token.clone(),
|
||||
expires_at,
|
||||
project_id: self.project_id.clone(),
|
||||
scopes: self.scopes.clone(),
|
||||
};
|
||||
|
||||
// Store updated token data
|
||||
self.credentials_manager
|
||||
.write_credentials(&token_data)
|
||||
.map(|_| debug!("Successfully stored token data"))
|
||||
.unwrap_or_else(|e| error!("Failed to store token data: {}", e));
|
||||
|
||||
Ok(access_token)
|
||||
}
|
||||
|
||||
fn start_redirect_server(
|
||||
) -> Result<(AuthorizationCode, CsrfToken), Box<dyn Error + Send + Sync>> {
|
||||
let listener = TcpListener::bind("127.0.0.1:18080")?;
|
||||
println!("Listening for the authorization code on http://localhost:18080");
|
||||
|
||||
for stream in listener.incoming() {
|
||||
match stream {
|
||||
Ok(mut stream) => {
|
||||
let mut reader = BufReader::new(&stream);
|
||||
let mut request_line = String::new();
|
||||
reader.read_line(&mut request_line)?;
|
||||
|
||||
let redirect_url = request_line
|
||||
.split_whitespace()
|
||||
.nth(1)
|
||||
.ok_or("Invalid request")?;
|
||||
|
||||
let url = Url::parse(&format!("http://localhost{}", redirect_url))?;
|
||||
|
||||
let code = url
|
||||
.query_pairs()
|
||||
.find(|(key, _)| key == "code")
|
||||
.map(|(_, value)| AuthorizationCode::new(value.into_owned()))
|
||||
.ok_or("No code found in the response")?;
|
||||
|
||||
let state = url
|
||||
.query_pairs()
|
||||
.find(|(key, _)| key == "state")
|
||||
.map(|(_, value)| CsrfToken::new(value.into_owned()))
|
||||
.ok_or("No state found in the response")?;
|
||||
|
||||
// Send a success response to the browser
|
||||
let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n\
|
||||
<html><body><h1>Authentication successful!</h1>\
|
||||
<p>You can now close this window and return to the application.</p></body></html>";
|
||||
|
||||
stream.write_all(response.as_bytes())?;
|
||||
stream.flush()?;
|
||||
|
||||
return Ok((code, state));
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to accept connection: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err("Failed to receive authorization code".into())
|
||||
}
|
||||
}
|
||||
|
||||
// impl GetToken for use with DriveHub directly
|
||||
// see google_drive3::common::GetToken
|
||||
impl GetToken for PkceOAuth2Client {
|
||||
fn get_token<'a>(
|
||||
&'a self,
|
||||
scopes: &'a [&str],
|
||||
) -> Pin<
|
||||
Box<dyn Future<Output = Result<Option<String>, Box<dyn Error + Send + Sync>>> + Send + 'a>,
|
||||
> {
|
||||
Box::pin(async move {
|
||||
// Try to read token data from storage to check if we have a valid token
|
||||
if let Ok(token_data) = self.credentials_manager.read_credentials::<TokenData>() {
|
||||
// Verify the project_id matches
|
||||
if token_data.project_id == self.project_id {
|
||||
// Convert stored scopes to &str slices for comparison
|
||||
let stored_scope_refs: Vec<&str> =
|
||||
token_data.scopes.iter().map(|s| s.as_str()).collect();
|
||||
|
||||
// Check if we need additional scopes
|
||||
let needs_additional_scopes = scopes.iter().any(|&scope| {
|
||||
!stored_scope_refs
|
||||
.iter()
|
||||
.any(|&stored| stored.contains(scope))
|
||||
});
|
||||
|
||||
if !needs_additional_scopes {
|
||||
// Check if the token is expired or expiring within a 5-min buffer
|
||||
if !self.is_token_expired(token_data.expires_at, 300) {
|
||||
return Ok(Some(token_data.access_token));
|
||||
}
|
||||
|
||||
// Token is expired or will expire soon, try to refresh it
|
||||
debug!("Token is expired or will expire soon, refreshing...");
|
||||
|
||||
// Try to refresh the token
|
||||
if let Ok(access_token) =
|
||||
self.refresh_token(&token_data.refresh_token).await
|
||||
{
|
||||
debug!("Successfully refreshed access token");
|
||||
return Ok(Some(access_token));
|
||||
}
|
||||
} else {
|
||||
// Only allocate new strings when we need to combine scopes
|
||||
let mut combined_scopes: Vec<&str> =
|
||||
Vec::with_capacity(scopes.len() + stored_scope_refs.len());
|
||||
combined_scopes.extend(scopes);
|
||||
combined_scopes.extend(stored_scope_refs.iter().filter(|&&stored| {
|
||||
!scopes.iter().any(|&scope| stored.contains(scope))
|
||||
}));
|
||||
|
||||
return self
|
||||
.perform_oauth_flow(&combined_scopes)
|
||||
.await
|
||||
.map(Some)
|
||||
.map_err(|e| {
|
||||
error!("OAuth flow failed: {}", e);
|
||||
e
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
// If we get here, either:
|
||||
// 1. The project ID didn't match
|
||||
// 2. Token refresh failed
|
||||
// 3. There are no valid tokens yet
|
||||
// 4. We didn't have to change the scopes of an existing token
|
||||
// Fallback: perform interactive OAuth flow
|
||||
self.perform_oauth_flow(scopes)
|
||||
.await
|
||||
.map(Some)
|
||||
.map_err(|e| {
|
||||
error!("OAuth flow failed: {}", e);
|
||||
e
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,344 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use keyring::Entry;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, error, warn};
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Error, Debug)]
|
||||
pub enum StorageError {
|
||||
#[error("Failed to access keychain: {0}")]
|
||||
KeyringError(#[from] keyring::Error),
|
||||
#[error("Failed to access file system: {0}")]
|
||||
FileSystemError(#[from] std::io::Error),
|
||||
#[error("No credentials found")]
|
||||
NotFound,
|
||||
#[error("Critical error: {0}")]
|
||||
Critical(String),
|
||||
#[error("Failed to serialize/deserialize: {0}")]
|
||||
SerializationError(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
/// CredentialsManager handles secure storage of OAuth credentials.
|
||||
/// It attempts to store credentials in the system keychain first,
|
||||
/// with fallback to file system storage if keychain access fails and fallback is enabled.
|
||||
pub struct CredentialsManager {
|
||||
credentials_path: String,
|
||||
fallback_to_disk: bool,
|
||||
keychain_service: String,
|
||||
keychain_username: String,
|
||||
}
|
||||
|
||||
impl CredentialsManager {
|
||||
pub fn new(
|
||||
credentials_path: String,
|
||||
fallback_to_disk: bool,
|
||||
keychain_service: String,
|
||||
keychain_username: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
credentials_path,
|
||||
fallback_to_disk,
|
||||
keychain_service,
|
||||
keychain_username,
|
||||
}
|
||||
}
|
||||
|
||||
/// Reads and deserializes credentials from secure storage.
|
||||
///
|
||||
/// This method attempts to read credentials from the system keychain first.
|
||||
/// If keychain access fails and fallback is enabled, it will try to read from the file system.
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `T` - The type to deserialize the credentials into. Must implement `serde::de::DeserializeOwned`.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(T)` - The deserialized credentials
|
||||
/// * `Err(StorageError)` - If reading or deserialization fails
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use goose_mcp::google_drive::storage::CredentialsManager;
|
||||
/// use serde::{Serialize, Deserialize};
|
||||
///
|
||||
/// #[derive(Serialize, Deserialize)]
|
||||
/// struct OAuthToken {
|
||||
/// access_token: String,
|
||||
/// refresh_token: String,
|
||||
/// expiry: u64,
|
||||
/// }
|
||||
///
|
||||
/// let manager = CredentialsManager::new(
|
||||
/// String::from("/path/to/credentials.json"),
|
||||
/// true, // fallback to disk if keychain fails
|
||||
/// String::from("test_service"),
|
||||
/// String::from("test_user")
|
||||
/// );
|
||||
/// match manager.read_credentials::<OAuthToken>() {
|
||||
/// Ok(token) => println!("Token expires at: {}", token.expiry),
|
||||
/// Err(e) => eprintln!("Failed to read token: {}", e),
|
||||
/// }
|
||||
/// ```
|
||||
pub fn read_credentials<T>(&self) -> Result<T, StorageError>
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
let json_str = Entry::new(&self.keychain_service, &self.keychain_username)
|
||||
.and_then(|entry| entry.get_password())
|
||||
.inspect(|_| {
|
||||
debug!("Successfully read credentials from keychain");
|
||||
})
|
||||
.or_else(|e| {
|
||||
if self.fallback_to_disk {
|
||||
debug!("Falling back to file system due to keyring error: {}", e);
|
||||
self.read_from_file()
|
||||
} else {
|
||||
match e {
|
||||
keyring::Error::NoEntry => Err(StorageError::NotFound),
|
||||
_ => Err(StorageError::KeyringError(e)),
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
serde_json::from_str(&json_str).map_err(StorageError::SerializationError)
|
||||
}
|
||||
|
||||
fn read_from_file(&self) -> Result<String, StorageError> {
|
||||
let path = Path::new(&self.credentials_path);
|
||||
if path.exists() {
|
||||
match fs::read_to_string(path) {
|
||||
Ok(content) => {
|
||||
debug!("Successfully read credentials from file system");
|
||||
Ok(content)
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to read credentials file: {}", e);
|
||||
Err(StorageError::FileSystemError(e))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
debug!("No credentials found in file system");
|
||||
Err(StorageError::NotFound)
|
||||
}
|
||||
}
|
||||
|
||||
/// Serializes and writes credentials to secure storage.
|
||||
///
|
||||
/// This method attempts to write credentials to the system keychain first.
|
||||
/// If keychain access fails and fallback is enabled, it will try to write to the file system.
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `T` - The type to serialize. Must implement `serde::Serialize`.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `content` - The data to serialize and store
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(())` - If writing succeeds
|
||||
/// * `Err(StorageError)` - If serialization or writing fails
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use goose_mcp::google_drive::storage::CredentialsManager;
|
||||
/// use serde::{Serialize, Deserialize};
|
||||
///
|
||||
/// #[derive(Serialize, Deserialize)]
|
||||
/// struct OAuthToken {
|
||||
/// access_token: String,
|
||||
/// refresh_token: String,
|
||||
/// expiry: u64,
|
||||
/// }
|
||||
///
|
||||
/// let token = OAuthToken {
|
||||
/// access_token: String::from("access_token_value"),
|
||||
/// refresh_token: String::from("refresh_token_value"),
|
||||
/// expiry: 1672531200, // Unix timestamp
|
||||
/// };
|
||||
///
|
||||
/// let manager = CredentialsManager::new(
|
||||
/// String::from("/path/to/credentials.json"),
|
||||
/// true, // fallback to disk if keychain fails
|
||||
/// String::from("test_service"),
|
||||
/// String::from("test_user")
|
||||
/// );
|
||||
/// if let Err(e) = manager.write_credentials(&token) {
|
||||
/// eprintln!("Failed to write token: {}", e);
|
||||
/// }
|
||||
/// ```
|
||||
pub fn write_credentials<T>(&self, content: &T) -> Result<(), StorageError>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
let json_str = serde_json::to_string(content).map_err(StorageError::SerializationError)?;
|
||||
|
||||
Entry::new(&self.keychain_service, &self.keychain_username)
|
||||
.and_then(|entry| entry.set_password(&json_str))
|
||||
.inspect(|_| {
|
||||
debug!("Successfully wrote credentials to keychain");
|
||||
})
|
||||
.or_else(|e| {
|
||||
if self.fallback_to_disk {
|
||||
warn!("Falling back to file system due to keyring error: {}", e);
|
||||
self.write_to_file(&json_str)
|
||||
} else {
|
||||
Err(StorageError::KeyringError(e))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn write_to_file(&self, content: &str) -> Result<(), StorageError> {
|
||||
let path = Path::new(&self.credentials_path);
|
||||
if let Some(parent) = path.parent() {
|
||||
if !parent.exists() {
|
||||
match fs::create_dir_all(parent) {
|
||||
Ok(_) => debug!("Created parent directories for credentials file"),
|
||||
Err(e) => {
|
||||
error!("Failed to create directories for credentials file: {}", e);
|
||||
return Err(StorageError::FileSystemError(e));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match fs::write(path, content) {
|
||||
Ok(_) => {
|
||||
debug!("Successfully wrote credentials to file system");
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to write credentials to file system: {}", e);
|
||||
Err(StorageError::FileSystemError(e))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for CredentialsManager {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
credentials_path: self.credentials_path.clone(),
|
||||
fallback_to_disk: self.fallback_to_disk,
|
||||
keychain_service: self.keychain_service.clone(),
|
||||
keychain_username: self.keychain_username.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||
struct TestCredentials {
|
||||
access_token: String,
|
||||
refresh_token: String,
|
||||
expiry: u64,
|
||||
}
|
||||
|
||||
impl TestCredentials {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
access_token: "test_access_token".to_string(),
|
||||
refresh_token: "test_refresh_token".to_string(),
|
||||
expiry: 1672531200,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_read_write_from_keychain() {
|
||||
// Create a temporary directory for test files
|
||||
let temp_dir = tempdir().expect("Failed to create temp dir");
|
||||
let cred_path = temp_dir.path().join("test_credentials.json");
|
||||
let cred_path_str = cred_path.to_str().unwrap().to_string();
|
||||
|
||||
// Create a credentials manager with fallback enabled
|
||||
// Using a unique service name to ensure keychain operation fails
|
||||
let manager = CredentialsManager::new(
|
||||
cred_path_str,
|
||||
true, // fallback to disk
|
||||
"test_service".to_string(),
|
||||
"test_user".to_string(),
|
||||
);
|
||||
|
||||
// Test credentials to store
|
||||
let creds = TestCredentials::new();
|
||||
|
||||
// Write should write to keychain
|
||||
let write_result = manager.write_credentials(&creds);
|
||||
assert!(write_result.is_ok(), "Write should succeed with fallback");
|
||||
|
||||
// Read should read from keychain
|
||||
let read_result = manager.read_credentials::<TestCredentials>();
|
||||
assert!(read_result.is_ok(), "Read should succeed with fallback");
|
||||
|
||||
// Verify the read credentials match what we wrote
|
||||
assert_eq!(
|
||||
read_result.unwrap(),
|
||||
creds,
|
||||
"Read credentials should match written credentials"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_fallback_not_found() {
|
||||
// Create a temporary directory for test files
|
||||
let temp_dir = tempdir().expect("Failed to create temp dir");
|
||||
let cred_path = temp_dir.path().join("nonexistent_credentials.json");
|
||||
let cred_path_str = cred_path.to_str().unwrap().to_string();
|
||||
|
||||
// Create a credentials manager with fallback disabled
|
||||
let manager = CredentialsManager::new(
|
||||
cred_path_str,
|
||||
false, // no fallback to disk
|
||||
"test_service_that_should_not_exist".to_string(),
|
||||
"test_user_no_fallback".to_string(),
|
||||
);
|
||||
|
||||
// Read should fail with NotFound or KeyringError depending on the system
|
||||
let read_result = manager.read_credentials::<TestCredentials>();
|
||||
println!("{:?}", read_result);
|
||||
assert!(
|
||||
read_result.is_err(),
|
||||
"Read should fail when credentials don't exist"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serialization_error() {
|
||||
// This test verifies that serialization errors are properly handled
|
||||
let error = serde_json::from_str::<TestCredentials>("invalid json").unwrap_err();
|
||||
let storage_error = StorageError::SerializationError(error);
|
||||
assert!(matches!(storage_error, StorageError::SerializationError(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_file_system_error_handling() {
|
||||
// Test handling of file system errors by using an invalid path
|
||||
let invalid_path = String::from("/nonexistent_directory/credentials.json");
|
||||
let manager = CredentialsManager::new(
|
||||
invalid_path,
|
||||
true,
|
||||
"test_service".to_string(),
|
||||
"test_user".to_string(),
|
||||
);
|
||||
|
||||
// Create test credentials
|
||||
let creds = TestCredentials::new();
|
||||
|
||||
// Attempt to write to an invalid path should result in FileSystemError
|
||||
let result = manager.write_to_file(&serde_json::to_string(&creds).unwrap());
|
||||
assert!(matches!(result, Err(StorageError::FileSystemError(_))));
|
||||
}
|
||||
}
|
||||
@@ -9,12 +9,10 @@ pub static APP_STRATEGY: Lazy<AppStrategyArgs> = Lazy::new(|| AppStrategyArgs {
|
||||
|
||||
pub mod computercontroller;
|
||||
mod developer;
|
||||
pub mod google_drive;
|
||||
mod memory;
|
||||
mod tutorial;
|
||||
|
||||
pub use computercontroller::ComputerControllerRouter;
|
||||
pub use developer::DeveloperRouter;
|
||||
pub use google_drive::GoogleDriveRouter;
|
||||
pub use memory::MemoryRouter;
|
||||
pub use tutorial::TutorialRouter;
|
||||
|
||||
@@ -1,29 +1,27 @@
|
||||
use anyhow::Result;
|
||||
use goose_mcp::{
|
||||
ComputerControllerRouter, DeveloperRouter, GoogleDriveRouter, MemoryRouter, TutorialRouter,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use goose_mcp::{ComputerControllerRouter, DeveloperRouter, MemoryRouter, TutorialRouter};
|
||||
use mcp_server::router::RouterService;
|
||||
use mcp_server::{BoundedService, ByteTransport, Server};
|
||||
use tokio::io::{stdin, stdout};
|
||||
|
||||
pub async fn run(name: &str) -> Result<()> {
|
||||
// Initialize logging
|
||||
crate::logging::setup_logging(Some(&format!("mcp-{name}")))?;
|
||||
|
||||
if name == "googledrive" || name == "google_drive" {
|
||||
return Err(anyhow!(
|
||||
"the built-in Google Drive extension has been removed"
|
||||
));
|
||||
}
|
||||
|
||||
tracing::info!("Starting MCP server");
|
||||
let router: Option<Box<dyn BoundedService>> = match name {
|
||||
"developer" => Some(Box::new(RouterService(DeveloperRouter::new()))),
|
||||
"computercontroller" => Some(Box::new(RouterService(ComputerControllerRouter::new()))),
|
||||
"google_drive" | "googledrive" => {
|
||||
let router = GoogleDriveRouter::new().await;
|
||||
Some(Box::new(RouterService(router)))
|
||||
}
|
||||
"memory" => Some(Box::new(RouterService(MemoryRouter::new()))),
|
||||
"tutorial" => Some(Box::new(RouterService(TutorialRouter::new()))),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
// Create and run the server
|
||||
let server = Server::new(router.unwrap_or_else(|| panic!("Unknown server requested {}", name)));
|
||||
let transport = ByteTransport::new(stdin(), stdout());
|
||||
|
||||
|
||||
@@ -20,6 +20,8 @@ type BundledExtension = {
|
||||
allow_configure?: boolean;
|
||||
};
|
||||
|
||||
const DEPRECATED_BUILTINS = ['googledrive', 'google_drive'];
|
||||
|
||||
/**
|
||||
* Synchronizes built-in extensions with the config system.
|
||||
* This function ensures all built-in extensions are added, which is especially
|
||||
@@ -37,6 +39,13 @@ export async function syncBundledExtensions(
|
||||
// Cast the imported JSON data to the expected type
|
||||
const bundledExtensions = bundledExtensionsData as BundledExtension[];
|
||||
|
||||
for (let i = existingExtensions.length - 1; i >= 0; i--) {
|
||||
const ext = existingExtensions[i];
|
||||
if (ext.type == 'builtin' && DEPRECATED_BUILTINS.includes(ext.name)) {
|
||||
existingExtensions.splice(i, 1);
|
||||
}
|
||||
}
|
||||
|
||||
// Process each bundled extension
|
||||
for (const bundledExt of bundledExtensions) {
|
||||
// Find if this extension already exists
|
||||
|
||||
Reference in New Issue
Block a user