1use std::sync::Arc;
15
16use arrow::array::{Array, ArrayRef, Float64Array, Int64Array, StructArray, UInt8Array};
17use arrow::datatypes::{Field, Schema};
18use arrow::error::ArrowError;
19use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema, to_ffi};
20use arrow::record_batch::RecordBatch;
21use powerio::{BusId, Network};
22
/// Table id that makes [`export`] produce the bus table.
pub const PIO_ARROW_TABLE_BUS: i32 = 0;
/// Table id that makes [`export`] produce the branch table.
pub const PIO_ARROW_TABLE_BRANCH: i32 = 1;
/// Table id that makes [`export`] produce the generator table.
pub const PIO_ARROW_TABLE_GEN: i32 = 2;
/// Table id that makes [`export`] produce the load table.
pub const PIO_ARROW_TABLE_LOAD: i32 = 3;
/// Table id that makes [`export`] produce the shunt table.
pub const PIO_ARROW_TABLE_SHUNT: i32 = 4;

// Compile-time guard: the build fails if any table id is renumbered.
// NOTE(review): presumably these values are mirrored outside this crate — confirm.
const _: () = {
    assert!(PIO_ARROW_TABLE_BUS == 0);
    assert!(PIO_ARROW_TABLE_BRANCH == 1);
    assert!(PIO_ARROW_TABLE_GEN == 2);
    assert!(PIO_ARROW_TABLE_LOAD == 3);
    assert!(PIO_ARROW_TABLE_SHUNT == 4);
};
41
42pub fn export(net: &Network, table: i32) -> Result<(FFI_ArrowArray, FFI_ArrowSchema), String> {
46 let rb = match table {
47 PIO_ARROW_TABLE_BUS => bus_batch(net),
48 PIO_ARROW_TABLE_BRANCH => branch_batch(net),
49 PIO_ARROW_TABLE_GEN => gen_batch(net),
50 PIO_ARROW_TABLE_LOAD => load_batch(net),
51 PIO_ARROW_TABLE_SHUNT => shunt_batch(net),
52 other => return Err(format!("unknown Arrow table id {other}")),
53 }
54 .map_err(|e| e.to_string())?;
55
56 let data = StructArray::from(rb).into_data();
58 to_ffi(&data).map_err(|e| e.to_string())
59}
60
61fn bus_batch(net: &Network) -> Result<RecordBatch, ArrowError> {
62 let b = &net.buses;
63 batch(vec![
64 ("id", i64s(b.iter().map(|x| ext(x.id)).collect())),
65 (
66 "kind",
67 i64s(b.iter().map(|x| i64::from(x.kind as u8)).collect()),
68 ),
69 ("vm", f64s(b.iter().map(|x| x.vm).collect())),
70 ("va", f64s(b.iter().map(|x| x.va).collect())),
71 ("base_kv", f64s(b.iter().map(|x| x.base_kv).collect())),
72 ("vmax", f64s(b.iter().map(|x| x.vmax).collect())),
73 ("vmin", f64s(b.iter().map(|x| x.vmin).collect())),
74 ("area", i64s(b.iter().map(|x| usz(x.area)).collect())),
75 ("zone", i64s(b.iter().map(|x| usz(x.zone)).collect())),
76 ])
77}
78
79fn branch_batch(net: &Network) -> Result<RecordBatch, ArrowError> {
80 let br = &net.branches;
81 batch(vec![
82 ("from", i64s(br.iter().map(|x| ext(x.from)).collect())),
83 ("to", i64s(br.iter().map(|x| ext(x.to)).collect())),
84 ("r", f64s(br.iter().map(|x| x.r).collect())),
85 ("x", f64s(br.iter().map(|x| x.x).collect())),
86 ("b", f64s(br.iter().map(|x| x.b).collect())),
87 ("rate_a", f64s(br.iter().map(|x| x.rate_a).collect())),
88 ("rate_b", f64s(br.iter().map(|x| x.rate_b).collect())),
89 ("rate_c", f64s(br.iter().map(|x| x.rate_c).collect())),
90 ("tap", f64s(br.iter().map(|x| x.tap).collect())),
91 ("shift", f64s(br.iter().map(|x| x.shift).collect())),
92 (
93 "in_service",
94 u8s(br.iter().map(|x| u8::from(x.in_service)).collect()),
95 ),
96 ("angmin", f64s(br.iter().map(|x| x.angmin).collect())),
97 ("angmax", f64s(br.iter().map(|x| x.angmax).collect())),
98 ])
99}
100
101fn gen_batch(net: &Network) -> Result<RecordBatch, ArrowError> {
102 let g = &net.generators;
103 batch(vec![
104 ("bus", i64s(g.iter().map(|x| ext(x.bus)).collect())),
105 ("pg", f64s(g.iter().map(|x| x.pg).collect())),
106 ("qg", f64s(g.iter().map(|x| x.qg).collect())),
107 ("pmax", f64s(g.iter().map(|x| x.pmax).collect())),
108 ("pmin", f64s(g.iter().map(|x| x.pmin).collect())),
109 ("qmax", f64s(g.iter().map(|x| x.qmax).collect())),
110 ("qmin", f64s(g.iter().map(|x| x.qmin).collect())),
111 ("vg", f64s(g.iter().map(|x| x.vg).collect())),
112 ("mbase", f64s(g.iter().map(|x| x.mbase).collect())),
113 (
114 "in_service",
115 u8s(g.iter().map(|x| u8::from(x.in_service)).collect()),
116 ),
117 ])
118}
119
120fn load_batch(net: &Network) -> Result<RecordBatch, ArrowError> {
121 let l = &net.loads;
122 batch(vec![
123 ("bus", i64s(l.iter().map(|x| ext(x.bus)).collect())),
124 ("p", f64s(l.iter().map(|x| x.p).collect())),
125 ("q", f64s(l.iter().map(|x| x.q).collect())),
126 (
127 "in_service",
128 u8s(l.iter().map(|x| u8::from(x.in_service)).collect()),
129 ),
130 ])
131}
132
133fn shunt_batch(net: &Network) -> Result<RecordBatch, ArrowError> {
134 let s = &net.shunts;
135 batch(vec![
136 ("bus", i64s(s.iter().map(|x| ext(x.bus)).collect())),
137 ("g", f64s(s.iter().map(|x| x.g).collect())),
138 ("b", f64s(s.iter().map(|x| x.b).collect())),
139 (
140 "in_service",
141 u8s(s.iter().map(|x| u8::from(x.in_service)).collect()),
142 ),
143 ])
144}
145
146fn batch(cols: Vec<(&str, ArrayRef)>) -> Result<RecordBatch, ArrowError> {
147 let fields: Vec<Field> = cols
148 .iter()
149 .map(|(name, arr)| Field::new(*name, arr.data_type().clone(), false))
150 .collect();
151 let arrays: Vec<ArrayRef> = cols.into_iter().map(|(_, arr)| arr).collect();
152 RecordBatch::try_new(Arc::new(Schema::new(fields)), arrays)
153}
154
155fn ext(id: BusId) -> i64 {
157 i64::try_from(id.0).unwrap_or(-1)
158}
159
/// Converts a `usize` to `i64` for export.
///
/// Yields `-1` when `n` does not fit in an `i64`.
fn usz(n: usize) -> i64 {
    let converted: Result<i64, _> = n.try_into();
    converted.unwrap_or(-1)
}
163
164fn i64s(v: Vec<i64>) -> ArrayRef {
165 Arc::new(Int64Array::from(v))
166}
167
168fn f64s(v: Vec<f64>) -> ArrayRef {
169 Arc::new(Float64Array::from(v))
170}
171
172fn u8s(v: Vec<u8>) -> ArrayRef {
173 Arc::new(UInt8Array::from(v))
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::ffi::from_ffi;

    /// Parses the named case file from `../tests/data` (relative to this crate's manifest
    /// directory) and returns its network. Panics if parsing fails.
    fn net(name: &str) -> Network {
        let mut path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("../tests/data");
        path.push(name);
        let parsed = powerio::parse_file(&path, None).unwrap();
        parsed.network
    }

    /// Exports `table` through the FFI structs and imports it back as a `StructArray`.
    fn round_trip(net: &Network, table: i32) -> StructArray {
        let (ffi_array, ffi_schema) = export(net, table).unwrap();
        // SAFETY: both structs were just produced together by `export` (via `to_ffi`),
        // so the schema describes the array being imported.
        let imported = unsafe { from_ffi(ffi_array, &ffi_schema) }.unwrap();
        imported.into()
    }

    #[test]
    fn bus_table_round_trips_with_external_ids() {
        let network = net("case9.m");
        let table = round_trip(&network, PIO_ARROW_TABLE_BUS);
        assert_eq!(table.len(), network.buses.len());

        let id_column = table.column_by_name("id").unwrap();
        let ids = id_column.as_any().downcast_ref::<Int64Array>().unwrap();

        // The exported ids must be the buses' own id values, in bus order.
        let expected: Vec<i64> = network
            .buses
            .iter()
            .map(|bus| i64::try_from(bus.id.0).unwrap())
            .collect();
        assert_eq!(ids.values(), expected.as_slice());
    }

    #[test]
    fn empty_table_exports_zero_rows() {
        let network = net("case9.m");
        // Precondition: this fixture has no shunts, so the export below is truly empty.
        assert_eq!(network.shunts.len(), 0);
        let table = round_trip(&network, PIO_ARROW_TABLE_SHUNT);
        assert_eq!(table.len(), 0);
    }

    #[test]
    fn every_table_has_the_expected_row_count() {
        let network = net("case30.m");
        let expectations = [
            (PIO_ARROW_TABLE_BUS, network.buses.len()),
            (PIO_ARROW_TABLE_BRANCH, network.branches.len()),
            (PIO_ARROW_TABLE_GEN, network.generators.len()),
            (PIO_ARROW_TABLE_LOAD, network.loads.len()),
            (PIO_ARROW_TABLE_SHUNT, network.shunts.len()),
        ];
        for (table, rows) in expectations {
            assert_eq!(round_trip(&network, table).len(), rows);
        }
    }

    #[test]
    fn unknown_table_id_errors() {
        let network = net("case9.m");
        let outcome = export(&network, 99);
        assert!(outcome.is_err());
    }
}