diff --git a/Cargo.toml b/Cargo.toml index bf59417..bb87e10 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,5 +12,6 @@ repository = "https://github.com/kitsuyui/sxd_html_table" [dependencies] csv = "1.1.6" +sxd-document = "0.3.2" sxd-xpath = "0.4.2" sxd_html = "0.1.1" diff --git a/src/lib.rs b/src/lib.rs index ea676bb..94e6a08 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -68,7 +68,7 @@ impl Table { } } -pub fn extract_tables_from_document(html: &str) -> Result, Error> { +pub fn extract_table_texts_from_document(html: &str) -> Result, Error> { let package = sxd_html::parse_html(html); let document = package.as_document(); #[allow(clippy::expect_used)] @@ -79,7 +79,26 @@ pub fn extract_tables_from_document(html: &str) -> Result, Error> { }; let mut tables = vec![]; for node in table_nodes.document_order() { - match extract_table(&node) { + match extract_table_texts(&node) { + Ok(table) => tables.push(table), + Err(e) => return Err(e), + } + } + Ok(tables) +} + +pub fn extract_table_elements_from_document(html: &str) -> Result, Error> { + let package = sxd_html::parse_html(html); + let document = package.as_document(); + #[allow(clippy::expect_used)] + let val = evaluate_xpath_node(document.root(), "//table").expect("XPath evaluation failed"); + + let Value::Nodeset(table_nodes) = val else { + panic!("Expected node set"); + }; + let mut tables = vec![]; + for node in table_nodes.document_order() { + match extract_table_elements(&node) { Ok(table) => tables.push(table), Err(e) => return Err(e), } @@ -90,26 +109,25 @@ pub fn extract_tables_from_document(html: &str) -> Result, Error> { fn extract_rowspan_and_colspan(node: &Node) -> (usize, usize) { #[allow(clippy::expect_used)] let element = node.element().expect("Expected element"); - let rowspan = element - .attribute_value("rowspan") - .unwrap_or("1") - .parse::() - .unwrap_or(1); - let colspan = element - .attribute_value("colspan") + let rowspan = extract_span(element, "rowspan"); + let colspan = extract_span(element, "colspan"); + (rowspan, colspan) +} + +fn extract_span(element: sxd_document::dom::Element, name: &str) -> usize { + element + .attribute_value(name) .unwrap_or("1") .parse::() - .unwrap_or(1); - (rowspan, colspan) + .unwrap_or(1) } -fn extract_table(node: &Node) -> Result { +pub fn map_table_cell(node: &Node, f: fn(&Node) -> String) -> Result { let tr_nodes = match evaluate_xpath_node(*node, "./tbody/tr") { Ok(Value::Nodeset(tr_nodes)) => tr_nodes, _ => return Err(Error::InvalidDocument), }; let tr_nodes = tr_nodes.document_order(); - let mut map: HashMap<(usize, usize), String> = HashMap::new(); let mut header_map: HashMap<(usize, usize), bool> = HashMap::new(); for (row_index, tr) in tr_nodes.iter().enumerate() { @@ -121,7 +139,7 @@ fn extract_table(node: &Node) -> Result { let mut col_index = 0; for (_, cell_node) in cell_nodes.iter().enumerate() { let (row_size, col_size) = extract_rowspan_and_colspan(cell_node); - let text = &cell_node.string_value(); + let text = f(cell_node); #[allow(clippy::expect_used)] let is_header = cell_node.element().expect("Expected element").name() == "th".into(); while map.contains_key(&(row_index, col_index)) { @@ -135,16 +153,42 @@ fn extract_table(node: &Node) -> Result { } } } + Ok(map_to_table(&map)) +} + +fn map_to_table(map: &HashMap<(usize, usize), String>) -> Table { let rows = map.keys().map(|(i, _)| i).max().unwrap_or(&0) + 1; let cols = map.keys().map(|(_, j)| j).max().unwrap_or(&0) + 1; let mut table = Table::new((rows, cols)); for ((i, j), text) in map { - table.cells[i * table.size.1 + j] = Some(text); + table.cells[i * table.size.1 + j] = Some(text.to_string()); } - for ((i, j), is_header) in header_map { - table.headers[i * table.size.1 + j] = is_header; + table +} + +fn extract_table_texts(node: &Node) -> Result { + map_table_cell(node, |node| node.string_value()) +} + +fn extract_table_elements(node: &Node) -> Result { + map_table_cell(node, |node| element_to_html(node)) +} + +fn element_to_html(node: &Node) -> String { + let mut buf = Vec::new(); + let package = sxd_document::Package::new(); + let doc = package.as_document(); + let root = doc.root(); + match node.element() { + Some(element) => { + root.append_child(element.clone()); + } + None => (), } - Ok(table) + #[allow(clippy::expect_used)] + sxd_document::writer::format_document(&doc, &mut buf).expect("Failed to format document"); + #[allow(clippy::expect_used)] + String::from_utf8(buf).expect("Failed to convert to UTF-8") } fn evaluate_xpath_node<'d>( @@ -179,7 +223,7 @@ mod tests { "#; - let result = extract_tables_from_document(html).unwrap(); + let result = extract_table_texts_from_document(html).unwrap(); assert_eq!(result.len(), 1); assert_eq!(result[0].to_csv().unwrap(), "1,2\n"); @@ -202,7 +246,7 @@ mod tests { "#; - let result = extract_tables_from_document(html).unwrap(); + let result = extract_table_texts_from_document(html).unwrap(); assert_eq!(result.len(), 2); assert_eq!(result[0].to_csv().unwrap(), "1,2\n",); assert_eq!(result[1].to_csv().unwrap(), "3,4\n",); @@ -218,12 +262,12 @@ mod tests { "#; - let result = extract_tables_from_document(html).unwrap(); + let result = extract_table_texts_from_document(html).unwrap(); assert_eq!(result.len(), 0); // empty html let html = r#""#; - let result = extract_tables_from_document(html).unwrap(); + let result = extract_table_texts_from_document(html).unwrap(); assert_eq!(result.len(), 0); } @@ -245,7 +289,7 @@ mod tests { "#; - let result = extract_tables_from_document(html).unwrap(); + let result = extract_table_texts_from_document(html).unwrap(); assert_eq!(result.len(), 1); assert_eq!(result[0].to_csv().unwrap(), "1,2\n3,4\n"); } @@ -267,7 +311,7 @@ mod tests { "#; - let result = extract_tables_from_document(html).unwrap(); + let result = extract_table_texts_from_document(html).unwrap(); assert_eq!(result.len(), 1); assert_eq!(result[0].to_csv().unwrap(), "A,B\nA,C\n"); @@ -288,7 +332,7 @@ mod tests { "#; - let result = extract_tables_from_document(html).unwrap(); + let result = extract_table_texts_from_document(html).unwrap(); assert_eq!(result.len(), 1); assert_eq!(result[0].to_csv().unwrap(), "A,A,B\nC,D,E\n"); @@ -333,7 +377,7 @@ mod tests { "#; - let result = extract_tables_from_document(html).unwrap(); + let result = extract_table_texts_from_document(html).unwrap(); assert_eq!(result.len(), 6); assert_eq!(result[0].to_csv().unwrap(), "A,A,B\nA,A,C\n"); assert_eq!(result[1].to_csv().unwrap(), "a,b,c\nd,e,f\n");