From b268b295e350570661afdfc0c6727f28909b8207 Mon Sep 17 00:00:00 2001 From: Anson Date: Tue, 15 Mar 2022 21:20:49 -0700 Subject: [PATCH] added legend, custom columns --- src/SplomPlots.jl | 63 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 47 insertions(+), 16 deletions(-) diff --git a/src/SplomPlots.jl b/src/SplomPlots.jl index adebfc0..21ccf02 100644 --- a/src/SplomPlots.jl +++ b/src/SplomPlots.jl @@ -4,18 +4,23 @@ using DataFrames using Plots using RDatasets -function splom(df::DataFrame) - columns = Dict(names(df) .=> eltype.(eachcol(df))) - - for (key, value) in columns - if value <: Number - continue - - else - pop!(columns, key) - end +function splom(df::DataFrame; group="", columns=[]) + if columns == [] + columns = Dict(names(df) .=> eltype.(eachcol(df))) + else + columns = Dict(names(df[!,columns]) .=> eltype.(eachcol(df[!,columns]))) end + for (key, value) in columns + if value <: Number + continue + + else + pop!(columns, key) + end + end + + println(columns) cols = collect(keys(columns)) col_pairs = [(x, y) for x in cols, y in cols] col_len = length(cols) @@ -25,12 +30,38 @@ function splom(df::DataFrame) for (i, (x, y)) in enumerate(col_pairs) scatter_plots[i] = plot() + if x == y - plot!(; - xaxis=nothing, yaxis=nothing, showaxis=false - ) + if i == col_len^2 && group != "" + fakedat = (1:length(df[!, x])) .* 0 + scatter!( + fakedat, + fakedat; + xaxis=nothing, + yaxis=nothing, + showaxis=false, + grid=false, + xlims=(1, 0), + ylims=(1, 0), + group=df[!, group], + legend=:topleft, + ) + else + plot!(; + xaxis=nothing, + yaxis=nothing, + showaxis=false, + ) + end else - scatter!(df[!, x], df[!, y]) + scatter!( + df[!, x], + df[!, y]; + group=df.Species, + label="", + ) + + # end end if i % col_len == 1 @@ -42,11 +73,11 @@ function splom(df::DataFrame) end end - return plot(scatter_plots...; label="") + return plot(scatter_plots...;) end df = dataset("datasets", "iris") -splom(df) +splom(df; group=:Species,columns=[:SepalLength, :SepalWidth, :PetalWidth]) end # module