diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4334596c..6d767962 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -74,15 +74,12 @@ jobs: needs: [test] steps: - uses: actions/checkout@v6 - - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - override: true - - run: | - pip install psycopg + - uses: cachix/install-nix-action@v31 + - name: Initialize podman + run: nix develop --command podman info >/dev/null 2>&1 || true - run: | cd tests-integration - ./test.sh + nix develop --command bash test.sh msrv: name: MSRV diff --git a/Cargo.lock b/Cargo.lock index 9e5b89d8..d64eccaf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -392,7 +392,7 @@ dependencies = [ "geo-traits", "geoarrow", "geoarrow-schema", - "pg_interval_2", + "pg_interval", "pgwire", "postgis", "postgres-types", @@ -1908,7 +1908,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -3122,10 +3122,10 @@ dependencies = [ ] [[package]] -name = "pg_interval_2" -version = "0.5.1" +name = "pg_interval" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "469827e70c8c74562f88b9434cf8a8fe35665281d2442304e99efcadf8f76a8f" +checksum = "c386dd54fce258fc04e668126ae68589a0d92e03a90ea67881d1300f70fd6170" dependencies = [ "bytes", "chrono", @@ -3134,9 +3134,9 @@ dependencies = [ [[package]] name = "pgwire" -version = "0.39.0" +version = "0.40.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3cee243b682091188b90f22b07585ad6b43e699b52bc29422bd9ca6ce2c2deb" +checksum = "c764131468c2049ee0d9324e1f465d45f9b16d54797cb75d401a60a39429dfa9" dependencies = [ "async-trait", "base64", @@ -3147,7 +3147,7 @@ dependencies = [ "hex", "lazy-regex", "md5", - "pg_interval_2", + "pg_interval", "postgis", "postgres-types", "rand 0.10.0", @@ -3699,7 +3699,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -4113,7 +4113,7 @@ dependencies = [ "getrandom 0.3.4", "once_cell", "rustix", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -4624,7 +4624,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 11447a92..8dbf2e31 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ bytes = "1.11.1" chrono = { version = "0.4", features = ["std"] } datafusion = { version = "53" } futures = "0.3" -pgwire = { version = "0.39", default-features = false } +pgwire = { version = "0.40", default-features = false } postgres-types = "0.2" rust_decimal = { version = "1.41", features = ["db-postgres"] } tokio = { version = "1", default-features = false } diff --git a/arrow-pg/Cargo.toml b/arrow-pg/Cargo.toml index 5c30366e..1daf5ca8 100644 --- a/arrow-pg/Cargo.toml +++ b/arrow-pg/Cargo.toml @@ -27,7 +27,7 @@ datafusion = { workspace = true, optional = true } futures.workspace = true geoarrow = { version = "0.8", optional = true } geoarrow-schema = { version = "0.8", optional = true } -pg_interval = { version = "0.5.1", package = "pg_interval_2" } +pg_interval = { version = "0.5.0" } pgwire = { workspace = true, default-features = false, features = ["server-api", "pg-ext-types"] } postgres-types.workspace = true rust_decimal.workspace = true diff --git a/datafusion-pg-catalog/src/sql/parser.rs b/datafusion-pg-catalog/src/sql/parser.rs index 719df904..a918b5b7 100644 --- a/datafusion-pg-catalog/src/sql/parser.rs +++ b/datafusion-pg-catalog/src/sql/parser.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use datafusion::sql::sqlparser::ast::Statement; use datafusion::sql::sqlparser::dialect::PostgreSqlDialect; +use datafusion::sql::sqlparser::keywords::Keyword; use datafusion::sql::sqlparser::parser::Parser; use datafusion::sql::sqlparser::parser::ParserError; use datafusion::sql::sqlparser::tokenizer::Token; @@ -247,10 +248,19 @@ impl PostgresCompatibilityParser { // Get token values (without spans) and filter out only whitespace // Keep semicolons as they separate statements + // Also rewrite ABORT to ROLLBACK for postgres compatibility + // remove this when https://github.com/apache/datafusion-sqlparser-rs/pull/2332 is ready let filtered_tokens: Vec = tokens .iter() .map(|t| t.token.clone()) .filter(|t| !matches!(t, Token::Whitespace(_))) + .map(|t| { + if matches!(&t, Token::Word(w) if w.keyword == Keyword::ABORT) { + Token::make_keyword("ROLLBACK") + } else { + t + } + }) .collect(); // Handle empty input diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index f774a3b0..c7d53203 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -16,12 +16,16 @@ use pgwire::api::portal::{Format, Portal}; use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{FieldInfo, Response, Tag}; use pgwire::api::stmt::QueryParser; -use pgwire::api::{ClientInfo, ConnectionManager, ErrorHandler, PgWireServerHandlers, Type}; +use pgwire::api::store::PortalStore; +use pgwire::api::{ + ClientInfo, ClientPortalStore, ConnectionManager, ErrorHandler, PgWireServerHandlers, Type, +}; use pgwire::error::{PgWireError, PgWireResult}; use pgwire::messages::PgWireBackendMessage; use pgwire::types::format::FormatOptions; use crate::hooks::QueryHook; +use crate::hooks::cursor::CursorStatementHook; use crate::hooks::set_show::SetShowHook; use crate::hooks::transactions::TransactionStatementHook; use crate::{client, planner}; @@ -121,8 +125,11 @@ pub struct DfSessionService { impl DfSessionService { pub fn new(session_context: Arc) -> DfSessionService { - let hooks: Vec> = - vec![Arc::new(SetShowHook), Arc::new(TransactionStatementHook)]; + let hooks: Vec> = vec![ + Arc::new(CursorStatementHook), + Arc::new(SetShowHook), + Arc::new(TransactionStatementHook), + ]; Self::new_with_hooks(session_context, hooks) } @@ -147,11 +154,18 @@ impl DfSessionService { impl SimpleQueryHandler for DfSessionService { async fn do_query(&self, client: &mut C, query: &str) -> PgWireResult> where - C: ClientInfo + futures::Sink + Unpin + Send + Sync, + C: ClientInfo + + ClientPortalStore + + futures::Sink + + Unpin + + Send + + Sync, + C::PortalStore: PortalStore, C::Error: std::fmt::Debug, PgWireError: From<>::Error>, { log::debug!("Received query: {query}"); + let statements = self .parser .sql_parser @@ -235,7 +249,13 @@ impl ExtendedQueryHandler for DfSessionService { _max_rows: usize, ) -> PgWireResult where - C: ClientInfo + futures::Sink + Unpin + Send + Sync, + C: ClientInfo + + ClientPortalStore + + futures::Sink + + Unpin + + Send + + Sync, + C::PortalStore: PortalStore, C::Error: std::fmt::Debug, PgWireError: From<>::Error>, { @@ -636,4 +656,246 @@ mod tests { assert!(!has_ps, "statement_timeout should not send ParameterStatus"); } + + fn assert_execution_tag(response: &Response, expected: &str) { + match response { + Response::Execution(tag) => { + let cc = pgwire::messages::response::CommandComplete::from(tag.clone()); + assert_eq!(cc.tag, expected, "Unexpected execution tag"); + } + other => panic!("Expected Execution response, got: {other:?}"), + } + } + + async fn assert_query_response_empty(response: &mut Response) { + use futures::StreamExt; + + let Response::Query(qr) = response else { + panic!("Expected Query response, got: {response:?}"); + }; + + let mut count = 0; + while qr.data_rows().next().await.is_some() { + count += 1; + } + assert_eq!(count, 0, "Expected no rows from exhausted cursor"); + } + + #[tokio::test] + async fn test_declare_fetch_close_cursor() { + let service = crate::testing::setup_handlers(); + let mut client = MockClient::new(); + + let responses = ::do_query( + &service, + &mut client, + "DECLARE test_cursor CURSOR FOR SELECT 1 AS col", + ) + .await + .unwrap(); + + assert_eq!(responses.len(), 1); + assert_execution_tag(&responses[0], "DECLARE CURSOR"); + + let responses = ::do_query( + &service, + &mut client, + "FETCH NEXT FROM test_cursor", + ) + .await + .unwrap(); + + assert_eq!(responses.len(), 1); + assert!( + matches!(&responses[0], Response::Query(_)), + "Expected Query response for FETCH" + ); + + let mut responses = ::do_query( + &service, + &mut client, + "FETCH NEXT FROM test_cursor", + ) + .await + .unwrap(); + + assert_eq!(responses.len(), 1); + assert_query_response_empty(&mut responses[0]).await; + + let responses = ::do_query( + &service, + &mut client, + "CLOSE test_cursor", + ) + .await + .unwrap(); + + assert_eq!(responses.len(), 1); + assert_execution_tag(&responses[0], "CLOSE CURSOR"); + } + + #[tokio::test] + async fn test_fetch_nonexistent_cursor() { + let service = crate::testing::setup_handlers(); + let mut client = MockClient::new(); + + let result = ::do_query( + &service, + &mut client, + "FETCH NEXT FROM nonexistent", + ) + .await; + + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_close_all_portals() { + let service = crate::testing::setup_handlers(); + let mut client = MockClient::new(); + + ::do_query( + &service, + &mut client, + "DECLARE c1 CURSOR FOR SELECT 1", + ) + .await + .unwrap(); + + ::do_query( + &service, + &mut client, + "DECLARE c2 CURSOR FOR SELECT 2", + ) + .await + .unwrap(); + + let responses = + ::do_query(&service, &mut client, "CLOSE ALL") + .await + .unwrap(); + + assert!(matches!(&responses[0], Response::Execution(_)),); + + let result = ::do_query( + &service, + &mut client, + "FETCH NEXT FROM c1", + ) + .await; + assert!(result.is_err(), "c1 should be closed"); + } + + #[tokio::test] + async fn test_fetch_forward_n() { + let service = crate::testing::setup_handlers(); + let mut client = MockClient::new(); + + ::do_query( + &service, + &mut client, + "CREATE TABLE nums AS SELECT 1 AS n UNION ALL SELECT 2 UNION ALL SELECT 3 UNION ALL SELECT 4 UNION ALL SELECT 5", + ) + .await + .unwrap(); + + ::do_query( + &service, + &mut client, + "DECLARE mycur CURSOR FOR SELECT n FROM nums ORDER BY n", + ) + .await + .unwrap(); + + let responses = ::do_query( + &service, + &mut client, + "FETCH FORWARD 3 FROM mycur", + ) + .await + .unwrap(); + + assert!( + matches!(&responses[0], Response::Query(_)), + "Expected Query response for FORWARD 3" + ); + + let responses = ::do_query( + &service, + &mut client, + "FETCH FORWARD ALL FROM mycur", + ) + .await + .unwrap(); + + let resp_desc = match &responses[0] { + Response::Query(_) => "Query".to_string(), + Response::Execution(tag) => { + let cc = pgwire::messages::response::CommandComplete::from(tag.clone()); + format!("Execution({})", cc.tag) + } + other => format!("{:?}", other), + }; + assert!( + matches!(&responses[0], Response::Query(_)), + "Expected Query response for remaining rows, got: {resp_desc}" + ); + + let mut responses = ::do_query( + &service, + &mut client, + "FETCH NEXT FROM mycur", + ) + .await + .unwrap(); + + assert_query_response_empty(&mut responses[0]).await; + } + + #[tokio::test] + async fn test_scroll_cursor_error() { + let service = crate::testing::setup_handlers(); + let mut client = MockClient::new(); + + ::do_query( + &service, + &mut client, + "DECLARE mycur CURSOR FOR SELECT 1", + ) + .await + .unwrap(); + + let result = ::do_query( + &service, + &mut client, + "FETCH PRIOR FROM mycur", + ) + .await; + + assert!(result.is_err(), "PRIOR should fail on forward-only cursor"); + } + + #[tokio::test] + async fn test_move_cursor() { + let service = crate::testing::setup_handlers(); + let mut client = MockClient::new(); + + ::do_query( + &service, + &mut client, + "DECLARE mycur CURSOR FOR SELECT generate_series(1, 5) AS n", + ) + .await + .unwrap(); + + let responses = ::do_query( + &service, + &mut client, + "FETCH FORWARD 3 FROM mycur", + ) + .await + .unwrap(); + + assert!(matches!(&responses[0], Response::Query(_))); + } } diff --git a/datafusion-postgres/src/hooks/cursor.rs b/datafusion-postgres/src/hooks/cursor.rs new file mode 100644 index 00000000..2538ea53 --- /dev/null +++ b/datafusion-postgres/src/hooks/cursor.rs @@ -0,0 +1,223 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion::common::ParamValues; +use datafusion::logical_expr::LogicalPlan; +use datafusion::prelude::SessionContext; +use datafusion::sql::sqlparser; +use datafusion::sql::sqlparser::ast::{CloseCursor, DeclareType, FetchDirection}; +use pgwire::api::ClientInfo; +use pgwire::api::portal::{Format, Portal}; +use pgwire::api::results::{Response, Tag}; +use pgwire::api::stmt::StoredStatement; +use pgwire::api::store::{MemPortalStore, PortalStore}; +use pgwire::error::{PgWireError, PgWireResult}; + +use super::{HookClient, QueryHook}; +use crate::arrow_pg::datatypes::df; + +pub(crate) type DfStatement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>); + +/// Hook for processing cursor-related statements (DECLARE/FETCH/CLOSE) +#[derive(Debug)] +pub struct CursorStatementHook; + +#[async_trait] +impl QueryHook for CursorStatementHook { + async fn handle_simple_query( + &self, + statement: &sqlparser::ast::Statement, + session_context: &SessionContext, + client: &mut dyn HookClient, + ) -> Option> { + let store = client.portal_store(); + + match statement { + sqlparser::ast::Statement::Declare { stmts } => { + Some(handle_declare(store, stmts, session_context).await) + } + sqlparser::ast::Statement::Fetch { + name, direction, .. + } => Some(handle_fetch(store, name, direction).await), + sqlparser::ast::Statement::Close { cursor } => Some(handle_close(store, cursor)), + _ => None, + } + } + + async fn handle_extended_parse_query( + &self, + statement: &sqlparser::ast::Statement, + _session_context: &SessionContext, + _client: &(dyn ClientInfo + Send + Sync), + ) -> Option> { + match statement { + sqlparser::ast::Statement::Declare { .. } + | sqlparser::ast::Statement::Fetch { .. } + | sqlparser::ast::Statement::Close { .. } => Some(Ok(LogicalPlan::EmptyRelation( + datafusion::logical_expr::EmptyRelation { + produce_one_row: false, + schema: Arc::new(datafusion::common::DFSchema::empty()), + }, + ))), + _ => None, + } + } + + async fn handle_extended_query( + &self, + statement: &sqlparser::ast::Statement, + _logical_plan: &LogicalPlan, + _params: &ParamValues, + session_context: &SessionContext, + client: &mut dyn HookClient, + ) -> Option> { + let store = client.portal_store(); + + match statement { + sqlparser::ast::Statement::Declare { stmts } => { + Some(handle_declare(store, stmts, session_context).await) + } + sqlparser::ast::Statement::Fetch { + name, direction, .. + } => Some(handle_fetch(store, name, direction).await), + sqlparser::ast::Statement::Close { cursor } => Some(handle_close(store, cursor)), + _ => None, + } + } +} + +async fn handle_declare( + store: &MemPortalStore, + stmts: &[datafusion::sql::sqlparser::ast::Declare], + session_context: &SessionContext, +) -> PgWireResult { + for declare in stmts { + if declare.declare_type != Some(DeclareType::Cursor) { + return Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "42601".to_string(), + format!("unsupported DECLARE type: {:?}", declare.declare_type), + ), + ))); + } + + let cursor_name = match declare.names.first() { + Some(name) => name.value.clone(), + None => { + return Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "42601".to_string(), + "cursor name is required".to_string(), + ), + ))); + } + }; + + let for_query = match &declare.for_query { + Some(q) => q.to_string(), + None => { + return Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "42601".to_string(), + "DECLARE CURSOR requires a FOR query".to_string(), + ), + ))); + } + }; + + let df = session_context + .sql(&for_query) + .await + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + + let query_response = df::encode_dataframe(df, &Format::UnifiedText, None).await?; + + let stored_stmt = Arc::new(StoredStatement::new( + cursor_name.clone(), + (for_query, None), + vec![], + )); + + let portal = Portal::new_cursor(cursor_name.clone(), stored_stmt); + + portal.start(query_response).await; + + store.put_portal(Arc::new(portal)); + } + + Ok(Response::Execution(Tag::new("DECLARE CURSOR"))) +} + +async fn handle_fetch( + store: &MemPortalStore, + name: &datafusion::sql::sqlparser::ast::Ident, + direction: &FetchDirection, +) -> PgWireResult { + let cursor_name = &name.value; + + let max_rows = match direction { + FetchDirection::Next | FetchDirection::Forward { limit: None } => Some(1), + FetchDirection::Forward { limit: Some(v) } | FetchDirection::Count { limit: v } => { + parse_value_as_usize(v) + } + FetchDirection::ForwardAll | FetchDirection::All => None, + FetchDirection::Prior | FetchDirection::Backward { .. } | FetchDirection::BackwardAll => { + return Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "42000".to_string(), + "cursor can only scan forward".to_string(), + ), + ))); + } + FetchDirection::First + | FetchDirection::Last + | FetchDirection::Absolute { .. } + | FetchDirection::Relative { .. } => { + return Err(PgWireError::UserError(Box::new( + pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "42000".to_string(), + "cursor can only scan forward".to_string(), + ), + ))); + } + }; + + let portal = store.get_portal(cursor_name).ok_or_else(|| { + PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new( + "ERROR".to_string(), + "34000".to_string(), + format!("cursor \"{cursor_name}\" does not exist"), + ))) + })?; + + let fetch_result = portal.fetch(max_rows.unwrap_or(0)).await?; + + Ok(Response::Query(fetch_result.response)) +} + +fn handle_close( + store: &MemPortalStore, + cursor: &CloseCursor, +) -> PgWireResult { + match cursor { + CloseCursor::All => { + store.clear_portals(); + } + CloseCursor::Specific { name } => { + store.rm_portal(&name.value); + } + } + Ok(Response::Execution(Tag::new("CLOSE CURSOR"))) +} + +fn parse_value_as_usize(value: &datafusion::sql::sqlparser::ast::Value) -> Option { + match value { + datafusion::sql::sqlparser::ast::Value::Number(s, _) => s.parse().ok(), + _ => None, + } +} diff --git a/datafusion-postgres/src/hooks/mod.rs b/datafusion-postgres/src/hooks/mod.rs index 48d5c969..c43b5b49 100644 --- a/datafusion-postgres/src/hooks/mod.rs +++ b/datafusion-postgres/src/hooks/mod.rs @@ -1,3 +1,4 @@ +pub mod cursor; pub mod permissions; pub mod set_show; pub mod transactions; @@ -10,21 +11,35 @@ use datafusion::prelude::SessionContext; use datafusion::sql::sqlparser::ast::Statement; use futures::Sink; use pgwire::api::ClientInfo; +use pgwire::api::ClientPortalStore; use pgwire::api::results::Response; +use pgwire::api::store::{MemPortalStore, PortalStore}; use pgwire::error::{PgWireError, PgWireResult}; use pgwire::messages::PgWireBackendMessage; +use crate::hooks::cursor::DfStatement; + #[async_trait] pub trait HookClient: ClientInfo + Send + Sync { + fn portal_store(&self) -> &MemPortalStore; + async fn send_message(&mut self, item: PgWireBackendMessage) -> PgWireResult<()>; } #[async_trait] impl HookClient for S where - S: ClientInfo + Sink + Send + Sync + Unpin, + S: ClientInfo + ClientPortalStore + Sink + Send + Sync + Unpin, PgWireError: From<>::Error>, + S::PortalStore: PortalStore, { + fn portal_store(&self) -> &MemPortalStore { + self.portal_store() + .as_any() + .downcast_ref::>() + .expect("portal store is not MemPortalStore") + } + async fn send_message(&mut self, item: PgWireBackendMessage) -> PgWireResult<()> { use futures::SinkExt; self.send(item).await.map_err(PgWireError::from) diff --git a/datafusion-postgres/src/testing.rs b/datafusion-postgres/src/testing.rs index a1a322ed..8fb25e91 100644 --- a/datafusion-postgres/src/testing.rs +++ b/datafusion-postgres/src/testing.rs @@ -1,6 +1,8 @@ use std::{collections::HashMap, sync::Arc}; +use datafusion::logical_expr::LogicalPlan; use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion::sql::sqlparser; use datafusion_pg_catalog::pg_catalog::setup_pg_catalog; use futures::Sink; use pgwire::{ @@ -29,10 +31,12 @@ pub fn setup_handlers() -> DfSessionService { DfSessionService::new(Arc::new(session_context)) } +type DfStatement = (String, Option<(sqlparser::ast::Statement, LogicalPlan)>); + #[derive(Debug, Default)] pub struct MockClient { metadata: HashMap, - portal_store: MemPortalStore, + portal_store: MemPortalStore, pub sent_messages: Vec, session_extensions: SessionExtensions, } @@ -110,7 +114,7 @@ impl ClientInfo for MockClient { } impl ClientPortalStore for MockClient { - type PortalStore = MemPortalStore; + type PortalStore = MemPortalStore; fn portal_store(&self) -> &Self::PortalStore { &self.portal_store } diff --git a/flake.nix b/flake.nix index 4591903b..fe3516b4 100644 --- a/flake.nix +++ b/flake.nix @@ -43,6 +43,8 @@ cargo-nextest cargo-release curl + lsof + podman pythonEnv postgresql_18.out ]; diff --git a/tests-integration/test.sh b/tests-integration/test.sh index 044ca7eb..b5b276aa 100755 --- a/tests-integration/test.sh +++ b/tests-integration/test.sh @@ -5,11 +5,14 @@ set -e # Function to cleanup processes cleanup() { echo "๐Ÿงน Cleaning up processes..." - for pid in $CSV_PID $TRANSACTION_PID $PARQUET_PID $RBAC_PID $SSL_PID $POSTGIS_PID; do + for pid in $CSV_PID $TRANSACTION_PID $PARQUET_PID $RBAC_PID $SSL_PID $POSTGIS_PID $FDW_PID; do if [ ! -z "$pid" ]; then kill -9 $pid 2>/dev/null || true fi done + if [ ! -z "$FDW_PG_CONTAINER" ]; then + podman rm -f $FDW_PG_CONTAINER 2>/dev/null || true + fi } # Trap to cleanup on exit @@ -41,19 +44,6 @@ cd .. cargo build --features datafusion-postgres/postgis cd tests-integration -# Set up test environment - -# Create virtual environment if it doesn't exist -if [ ! -d "test_env" ]; then - echo "Creating Python virtual environment..." - python3 -m venv test_env -fi - -# Activate virtual environment and install dependencies -echo "Setting up Python dependencies..." -source test_env/bin/activate -pip install -q psycopg - # Test 1: CSV data loading and PostgreSQL compatibility echo "" echo "๐Ÿ“Š Test 1: Enhanced CSV Data Loading & PostgreSQL Compatibility" @@ -69,7 +59,7 @@ if ! ps -p $CSV_PID > /dev/null 2>&1; then exit 1 fi -if python3 test_csv.py; then +if python test_csv.py; then echo "โœ… Enhanced CSV test passed" else echo "โŒ Enhanced CSV test failed" @@ -80,16 +70,85 @@ fi kill -9 $CSV_PID 2>/dev/null || true sleep 3 -# Test 2: Transaction support +# Test 2: Foreign Data Wrapper (postgres_fdw) +echo "" +echo "๐ŸŒ Test 2: Foreign Data Wrapper (postgres_fdw)" +echo "-----------------------------------------------" + +# Start a PostgreSQL container for the FDW test +echo "Starting PostgreSQL container..." +FDW_PG_CONTAINER=$(podman run -d \ + -e POSTGRES_USER=postgres \ + -e POSTGRES_DB=fdw_test \ + -e POSTGRES_HOST_AUTH_METHOD=trust \ + -p 5435:5432 \ + docker.io/library/postgres:17) + +if [ -z "$FDW_PG_CONTAINER" ]; then + echo "โš ๏ธ Could not start PostgreSQL container, skipping FDW test" +else + echo "Waiting for PostgreSQL container to be ready..." + timeout=60 + count=0 + until pg_isready -h 127.0.0.1 -p 5435 -q 2>/dev/null; do + if [ $count -ge $timeout ]; then + echo "โŒ PostgreSQL container did not become ready within ${timeout}s" + echo "--- podman logs ---" + podman logs $FDW_PG_CONTAINER 2>&1 || true + echo "--- end logs ---" + podman rm -f $FDW_PG_CONTAINER 2>/dev/null || true + FDW_PG_CONTAINER="" + exit 1 + fi + sleep 1 + count=$((count + 1)) + done + echo " PostgreSQL container is ready (waited ${count}s)" + + # Start datafusion-postgres with CSV data for FDW target + wait_for_port 5433 + ../target/debug/datafusion-postgres-cli --host 0.0.0.0 -p 5433 --csv delhi:delhiclimate.csv & + FDW_PID=$! + sleep 5 + + if ! ps -p $FDW_PID > /dev/null 2>&1; then + echo "โŒ DataFusion server for FDW test failed to start" + podman rm -f $FDW_PG_CONTAINER 2>/dev/null || true + exit 1 + fi + + # Run FDW test + export PGHOST=127.0.0.1 + export PGPORT=5435 + export PGUSER=postgres + export PGDATABASE=fdw_test + export DF_PORT=5433 + + if python test_fdw.py; then + echo "โœ… FDW test passed" + else + echo "โŒ FDW test failed" + kill -9 $FDW_PID 2>/dev/null || true + podman rm -f $FDW_PG_CONTAINER 2>/dev/null || true + exit 1 + fi + + kill -9 $FDW_PID 2>/dev/null || true + podman rm -f $FDW_PG_CONTAINER 2>/dev/null || true + FDW_PG_CONTAINER="" + sleep 3 +fi + +# Test 3: Transaction support echo "" -echo "๐Ÿ” Test 2: Transaction Support" +echo "๐Ÿ” Test 3: Transaction Support" echo "------------------------------" wait_for_port 5433 ../target/debug/datafusion-postgres-cli -p 5433 --csv delhi:delhiclimate.csv & TRANSACTION_PID=$! sleep 5 -if python3 test_transactions.py; then +if python test_transactions.py; then echo "โœ… Transaction test passed" else echo "โŒ Transaction test failed" @@ -100,16 +159,16 @@ fi kill -9 $TRANSACTION_PID 2>/dev/null || true sleep 3 -# Test 3: Parquet data loading and advanced data types +# Test 4: Parquet data loading and advanced data types echo "" -echo "๐Ÿ“ฆ Test 3: Enhanced Parquet Data Loading & Advanced Data Types" +echo "๐Ÿ“ฆ Test 4: Enhanced Parquet Data Loading & Advanced Data Types" echo "--------------------------------------------------------------" wait_for_port 5434 ../target/debug/datafusion-postgres-cli -p 5434 --parquet all_types:all_types.parquet & PARQUET_PID=$! sleep 5 -if python3 test_parquet.py; then +if python test_parquet.py; then echo "โœ… Enhanced Parquet test passed" else echo "โŒ Enhanced Parquet test failed" @@ -135,7 +194,7 @@ if ! ps -p $SSL_PID > /dev/null 2>&1; then exit 1 fi -if python3 test_ssl.py; then +if python test_ssl.py; then echo "โœ… SSL/TLS test passed" else echo "โŒ SSL/TLS test failed" @@ -161,7 +220,7 @@ if ! ps -p $POSTGIS_PID > /dev/null 2>&1; then exit 1 fi -if python3 test_postgis.py; then +if python test_postgis.py; then echo "โœ… PostGIS test passed" else echo "โŒ PostGIS test failed" @@ -177,6 +236,7 @@ echo "==========================================" echo "" echo "๐Ÿ“ˆ Test Summary:" echo " โœ… Enhanced CSV data loading with PostgreSQL compatibility" +echo " โœ… Foreign Data Wrapper (postgres_fdw) support" echo " โœ… Complete transaction support (BEGIN/COMMIT/ROLLBACK)" echo " โœ… Enhanced Parquet data loading with advanced data types" echo " โœ… Array types and complex data type support" diff --git a/tests-integration/test_fdw.py b/tests-integration/test_fdw.py new file mode 100644 index 00000000..3ddd8ef5 --- /dev/null +++ b/tests-integration/test_fdw.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +""" +Test postgres_fdw foreign data wrapper support. + +Requires a running PostgreSQL instance accessible via PGHOST/PGPORT/PGUSER/PGDATABASE env vars, +or defaults to localhost:5432 with user postgres and database fdw_test. + +The datafusion-postgres server should be running on port specified by DF_PORT (default 5433). +""" + +import os +import sys +import psycopg + + +DF_PORT = os.environ.get("DF_PORT", "5433") +PG_HOST = os.environ.get("PGHOST", "127.0.0.1") +PG_PORT = os.environ.get("PGPORT", "5432") +PG_USER = os.environ.get("PGUSER", "postgres") +PG_DB = os.environ.get("PGDATABASE", "fdw_test") + + +def main(): + print("๐ŸŒ Testing Foreign Data Wrapper (postgres_fdw)") + print("=" * 50) + + try: + conn = psycopg.connect( + f"host={PG_HOST} port={PG_PORT} user={PG_USER} dbname={PG_DB}", + autocommit=True, + ) + pg_version = conn.info.server_version + print(f" Connected to PostgreSQL {pg_version // 10000}.{pg_version % 10000 // 100}") + + setup_fdw(conn) + test_basic_query(conn) + test_aggregate_query(conn) + test_multiple_rows(conn) + test_order_by(conn) + test_cursor_lifecycle(conn) + cleanup_fdw(conn) + + conn.close() + print("\nโœ… All FDW tests passed!") + return 0 + + except Exception as e: + print(f"\nโŒ FDW tests failed: {e}") + import traceback + traceback.print_exc() + return 1 + + +def setup_fdw(conn): + """Set up the foreign data wrapper connecting to datafusion-postgres.""" + with conn.cursor() as cur: + cur.execute("CREATE EXTENSION IF NOT EXISTS postgres_fdw") + print(" โœ“ postgres_fdw extension installed") + + cur.execute("DROP SERVER IF EXISTS df_server CASCADE") + cur.execute(f""" + CREATE SERVER df_server + FOREIGN DATA WRAPPER postgres_fdw + OPTIONS (host 'host.containers.internal', port '{DF_PORT}', dbname 'postgres') + """) + print(" โœ“ Foreign server df_server created") + + cur.execute(f""" + CREATE USER MAPPING FOR current_user + SERVER df_server + OPTIONS (user 'postgres', password '') + """) + print(" โœ“ User mapping created") + + cur.execute(""" + IMPORT FOREIGN SCHEMA public + LIMIT TO (delhi) + FROM SERVER df_server + INTO public + """) + print(" โœ“ Foreign table delhi imported") + + +def test_basic_query(conn): + """Test basic SELECT through FDW.""" + with conn.cursor() as cur: + cur.execute("SELECT count(*) FROM delhi") + result = cur.fetchone()[0] + assert result > 0, f"Expected rows in delhi, got {result}" + print(f" โœ“ Basic query: {result} rows in delhi") + + +def test_aggregate_query(conn): + """Test aggregate functions through FDW.""" + with conn.cursor() as cur: + cur.execute("SELECT avg(meantemp), max(humidity) FROM delhi") + row = cur.fetchone() + assert row[0] is not None, "Expected non-null avg(meantemp)" + assert row[1] is not None, "Expected non-null max(humidity)" + print(f" โœ“ Aggregate query: avg(meantemp)={row[0]:.2f}, max(humidity)={row[1]}") + + +def test_multiple_rows(conn): + """Test fetching multiple rows through FDW.""" + with conn.cursor() as cur: + cur.execute("SELECT date, meantemp FROM delhi ORDER BY date LIMIT 5") + rows = cur.fetchall() + assert len(rows) == 5, f"Expected 5 rows, got {len(rows)}" + print(f" โœ“ Multiple rows: fetched {len(rows)} rows") + + +def test_order_by(conn): + """Test ORDER BY through FDW.""" + with conn.cursor() as cur: + cur.execute("SELECT date, meantemp FROM delhi ORDER BY meantemp DESC LIMIT 3") + rows = cur.fetchall() + temps = [row[1] for row in rows] + assert temps == sorted(temps, reverse=True), "Expected descending order" + print(f" โœ“ ORDER BY: top 3 temps = {temps}") + + +def test_cursor_lifecycle(conn): + """Test DECLARE/FETCH/CLOSE cursor through FDW.""" + with conn.cursor() as cur: + cur.execute("BEGIN") + cur.execute("DECLARE fdw_cur CURSOR FOR SELECT date, meantemp FROM delhi ORDER BY date") + print(" โœ“ DECLARE CURSOR") + + cur.execute("FETCH FORWARD 3 FROM fdw_cur") + rows = cur.fetchall() + assert len(rows) == 3, f"Expected 3 rows from FETCH, got {len(rows)}" + print(f" โœ“ FETCH FORWARD 3: got {len(rows)} rows") + + cur.execute("FETCH NEXT FROM fdw_cur") + row = cur.fetchone() + assert row is not None, "Expected a row from FETCH NEXT" + print(" โœ“ FETCH NEXT: got 1 row") + + cur.execute("CLOSE fdw_cur") + cur.execute("COMMIT") + print(" โœ“ CLOSE + COMMIT") + + +def cleanup_fdw(conn): + """Clean up FDW objects.""" + with conn.cursor() as cur: + cur.execute("DROP FOREIGN TABLE IF EXISTS delhi CASCADE") + cur.execute("DROP USER MAPPING IF EXISTS FOR current_user SERVER df_server") + cur.execute("DROP SERVER IF EXISTS df_server CASCADE") + print(" โœ“ FDW objects cleaned up") + + +if __name__ == "__main__": + sys.exit(main())